Skip to content

Commit 105c432

Browse files
Investigate refactoring opportunities for batch management in Plugin and Compiler - validateModelBatch conditions
1 parent b9cbb0f commit 105c432

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/plugins/intel_npu/src/plugin/src/plugin.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,8 @@ bool validateModelBatch(const std::shared_ptr<const ov::Model>& model, Logger lo
541541
ov::Layout layout = ov::layout::get_layout(input);
542542

543543
// Batching on plugin is working only when batching is found on 0th dimension
544-
if ((shape.size() && shape[0].get_max_length() > 1) ||
545-
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == 0)) {
544+
if ((shape.size() && shape[intel_npu::utils::BATCH_AXIS].get_max_length() != intel_npu::utils::DEFAULT_BATCH_SIZE) ||
545+
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == intel_npu::utils::BATCH_AXIS)) {
546546
const auto& staticShape = shape.is_dynamic() ? shape.get_max_shape() : input->get_shape();
547547
batchedInputs.insert(params[input_id]->output(0));
548548

@@ -579,8 +579,8 @@ bool validateModelBatch(const std::shared_ptr<const ov::Model>& model, Logger lo
579579
ov::Layout layout = ov::layout::get_layout(output);
580580

581581
// Batching on plugin is working only when batching is found on 0th dimension
582-
if ((shape.size() && shape[0].get_max_length() > 1) ||
583-
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == 0)) {
582+
if ((shape.size() && shape[intel_npu::utils::BATCH_AXIS].get_max_length() != intel_npu::utils::DEFAULT_BATCH_SIZE) ||
583+
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == intel_npu::utils::BATCH_AXIS)) {
584584
const auto& node = output->input_value(0);
585585
const auto& staticShape = shape.is_dynamic() ? shape.get_max_shape() : output->get_shape();
586586
batchedOutputs.insert(ov::Output<const ov::Node>(node.get_node(), node.get_index()));

0 commit comments

Comments
 (0)