
Commit b9cbb0f

Investigate refactoring opportunities for batch management in Plugin and Compiler - fix BA issues - treat every model with batch 1 as a potentially dynamically batched one
1 parent 7b6f81a commit b9cbb0f
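
The key idea of the change: any I/O whose batch axis is statically 1 is relaxed to a dynamic batch dimension, so the dummy model built from compiler metadata can later be reshaped to other batch sizes. As a rough, self-contained sketch of that relaxation (not the plugin code itself; BATCH_AXIS and DEFAULT_BATCH_SIZE are assumed stand-ins for the intel_npu::utils constants used in the diff):

#include <openvino/core/partial_shape.hpp>

#include <cstddef>
#include <cstdint>
#include <iostream>

// Assumed stand-ins for intel_npu::utils::BATCH_AXIS / ::DEFAULT_BATCH_SIZE.
constexpr std::size_t BATCH_AXIS = 0;
constexpr std::int64_t DEFAULT_BATCH_SIZE = 1;

int main() {
    // A shape as reported by the IR model or the compiler, e.g. NCHW with batch 1.
    ov::PartialShape shape{1, 3, 224, 224};

    // Treat every model with batch 1 as a potentially dynamically batched one:
    // relax the batch axis so a later reshape to batch N stays legal.
    if (shape[BATCH_AXIS] == DEFAULT_BATCH_SIZE) {
        shape[BATCH_AXIS] = ov::Dimension(-1);  // -1 denotes a fully dynamic dimension
    }

    std::cout << shape << std::endl;  // prints [?,3,224,224]
}

ov::Dimension(-1) is OpenVINO's notation for a fully dynamic dimension; it is the same value the diff below assigns to the batch axis.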

File tree

1 file changed: +29 −11 lines changed


src/plugins/intel_npu/src/plugin/src/plugin.cpp

Lines changed: 29 additions & 11 deletions
@@ -63,10 +63,17 @@ std::shared_ptr<ov::Model> create_dummy_model(const std::vector<IODescriptor>& i
             continue;
         }
 
+        auto shape = inputDescriptor.shapeFromIRModel.has_value() ? *inputDescriptor.shapeFromIRModel
+                                                                  : inputDescriptor.shapeFromCompiler;
+        // Treat every model with batch 1 as a potentially dynamically batched one.
+        // TODO: should we protect this part with a certain condition?
+        if (shape[intel_npu::utils::BATCH_AXIS] == intel_npu::utils::DEFAULT_BATCH_SIZE) {
+            shape[intel_npu::utils::BATCH_AXIS] = ov::Dimension(-1);
+        }
+
         std::shared_ptr<ov::op::v0::Parameter> parameter = std::make_shared<ov::op::v0::Parameter>(
             inputDescriptor.precision,
-            inputDescriptor.shapeFromIRModel.has_value() ? *inputDescriptor.shapeFromIRModel
-                                                         : inputDescriptor.shapeFromCompiler);
+            shape);
 
         parameter->set_friendly_name(inputDescriptor.nodeFriendlyName);
         parameter->output(0).get_tensor().set_names(inputDescriptor.outputTensorNames);
@@ -86,10 +93,16 @@ std::shared_ptr<ov::Model> create_dummy_model(const std::vector<IODescriptor>& i
         std::shared_ptr<ov::Node> constantDummy =
             std::make_shared<ov::op::v0::Constant>(outputDescriptor.precision, CONSTANT_NODE_DUMMY_SHAPE);
 
+        auto shape = outputDescriptor.shapeFromIRModel.has_value() ? *outputDescriptor.shapeFromIRModel
+                                                                   : outputDescriptor.shapeFromCompiler;
+        // Treat every model with batch 1 as a potentially dynamically batched one.
+        if (shape[intel_npu::utils::BATCH_AXIS] == intel_npu::utils::DEFAULT_BATCH_SIZE) {
+            shape[intel_npu::utils::BATCH_AXIS] = ov::Dimension(-1);
+        }
+
         const std::shared_ptr<ov::descriptor::Tensor>& tensorDummy = std::make_shared<ov::descriptor::Tensor>(
             outputDescriptor.precision,
-            outputDescriptor.shapeFromIRModel.has_value() ? *outputDescriptor.shapeFromIRModel
-                                                          : outputDescriptor.shapeFromCompiler,
+            shape,
             outputDescriptor.outputTensorNames);
 
         auto& result = results.emplace_back(std::make_shared<ov::op::v0::Result>(constantDummy));
@@ -684,17 +697,22 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
     if (localConfig.isAvailable(ov::intel_npu::batch_mode.name())) {
         bool autoOrPluginBatch = localConfig.get<BATCH_MODE>() == ov::intel_npu::BatchMode::PLUGIN ||
                                  localConfig.get<BATCH_MODE>() == ov::intel_npu::BatchMode::AUTO;
-        bool pluginBatchingIsSupported = validateModelBatch(modelForCompilation, _logger);
-        if (autoOrPluginBatch && pluginBatchingIsSupported) {
-            try {
+        try {
+            const bool pluginBatchingIsSupported = validateModelBatch(modelForCompilation, _logger);
+            const bool batchedModel = ov::get_batch(modelForCompilation) != intel_npu::utils::DEFAULT_BATCH_SIZE;
+
+            if (autoOrPluginBatch && pluginBatchingIsSupported && batchedModel) {
                 _logger.info("Attempting to handle batching on the plugin side.");
                 ov::set_batch(modelForCompilation, 1);
-            } catch (const std::exception& ex) {
-                _logger.info("Couldn't reshape the model. Batching will be handed by compiler.", ex.what());
+                // TODO: add debatcher for more complicated cases as set_batch is pretty naive.
+            } else {
+                _logger.info("Unable to manage batching on the plugin side, so the compiler will take care of it.");
             }
+
+            updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
+        } catch (const std::exception& ex) {
+            _logger.info("Couldn't validate and reshape the model. Batching will be handed by compiler.", ex.what());
             updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
-        } else {
-            _logger.info("Unable to manage batching on the plugin side, so the compiler will take care of it.");
         }
     }
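
In the compile_model hunk above, both batch checks now sit inside a single try block, so a throw from either validateModelBatch or ov::get_batch (which throws when no batch dimension can be identified in the model's layouts) cleanly falls back to compiler-side batching; previously validateModelBatch ran outside any try. A minimal sketch of that control flow, assuming a permissive validateModelBatch stand-in and omitting the logger and config plumbing:

#include <openvino/openvino.hpp>

#include <cstdint>
#include <memory>

// Stand-in for the NPU plugin's internal validateModelBatch helper.
static bool validateModelBatch(const std::shared_ptr<ov::Model>&) {
    return true;
}

void handleBatching(const std::shared_ptr<ov::Model>& model, bool autoOrPluginBatch) {
    constexpr std::int64_t DEFAULT_BATCH_SIZE = 1;  // assumed value of intel_npu::utils::DEFAULT_BATCH_SIZE
    try {
        // Either call may throw; ov::get_batch() in particular throws when the
        // model's parameters carry no batch ("N") layout information.
        const bool pluginBatchingIsSupported = validateModelBatch(model);
        const bool batchedModel = ov::get_batch(model) != DEFAULT_BATCH_SIZE;

        if (autoOrPluginBatch && pluginBatchingIsSupported && batchedModel) {
            // Debatch on the plugin side; ov::set_batch() is a naive reshape.
            ov::set_batch(model, 1);
        }
        // In both branches the real code then switches to COMPILER batch mode
        // (updateBatchMode), since the compiler sees a batch-1 model either way.
    } catch (const std::exception&) {
        // Validation or reshape failed: leave batching to the compiler.
    }
}

The commit also tightens the success path: the model is only debatched when it is actually batched (ov::get_batch != 1), whereas before the reshape was attempted whenever the mode and validation allowed it.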
