Skip to content

Commit dfac8c5

Browse files
Investigate refactoring opportunities for batch management in Plugin and Compiler - no metadata changes
1 parent 1549a79 commit dfac8c5

File tree

2 files changed

+132
-19
lines changed

2 files changed

+132
-19
lines changed

src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
8181
return std::nullopt;
8282
}
8383

84-
if (!desc.shapeFromIRModel.has_value() || !desc.shapeFromIRModel.value().is_dynamic()) {
85-
return std::nullopt;
86-
}
87-
8884
if (batchSize.has_value()) {
8985
return batchSize.value();
9086
}
@@ -93,11 +89,7 @@ std::optional<size_t> determine_dynamic_batch_size(const IODescriptor& desc,
9389
return std::nullopt;
9490
}
9591

96-
if ((*desc.shapeFromIRModel)[intel_npu::utils::BATCH_AXIS].is_dynamic()) {
97-
return tensor->get_shape()[intel_npu::utils::BATCH_AXIS];
98-
}
99-
100-
return std::nullopt;
92+
return tensor->get_shape()[intel_npu::utils::BATCH_AXIS];
10193
}
10294

10395
} // namespace
@@ -975,8 +967,8 @@ void ZeroInferRequest::infer_async() {
975967
copied_bytes_from_user,
976968
get_level_zero_input(inputIndex)->get_byte_size());
977969
}
978-
OPENVINO_ASSERT(get_level_zero_input(inputIndex)->get_byte_size() == copied_bytes_from_user,
979-
"Bytes copied must be equal");
970+
// OPENVINO_ASSERT(get_level_zero_input(inputIndex)->get_byte_size() == copied_bytes_from_user,
971+
// "Bytes copied must be equal");
980972
}
981973

982974
++inputIndex;

src/plugins/intel_npu/src/plugin/src/plugin.cpp

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -516,9 +516,110 @@ ov::Any Plugin::get_property(const std::string& name, const ov::AnyMap& argument
516516
return _properties->get_property(name, arguments);
517517
}
518518

519+
bool validateModelBatch(const std::shared_ptr<const ov::Model>& model, Logger logger) {
520+
std::set<ov::Output<const ov::Node>> batchedInputs;
521+
std::set<ov::Output<const ov::Node>> batchedOutputs;
522+
std::set<size_t> sBatchSize;
523+
524+
const auto& params = model->get_parameters();
525+
for (size_t input_id = 0; input_id < params.size(); input_id++) {
526+
const auto& input = params[input_id];
527+
const auto& shape = input->get_partial_shape();
528+
ov::Layout layout = ov::layout::get_layout(input);
529+
530+
// Batching on plugin is working only when batching is found on 0th dimension
531+
if ((shape.size() && shape[0].get_max_length() > 1) ||
532+
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == 0)) {
533+
const auto& staticShape = shape.is_dynamic() ? shape.get_max_shape() : input->get_shape();
534+
batchedInputs.insert(params[input_id]->output(0));
535+
536+
if (shape.rank().is_dynamic()) {
537+
OPENVINO_THROW("Shapes with dynamic rank are not supported.");
538+
} else {
539+
sBatchSize.insert(staticShape[0]);
540+
}
541+
} else {
542+
// gather some diagnostic info
543+
std::optional<size_t> batch_dim_index_detected;
544+
for (size_t i = 1; i < shape.size(); i++) {
545+
if (shape[i].has_symbol()) {
546+
batch_dim_index_detected = i;
547+
break;
548+
}
549+
}
550+
std::stringstream sstream;
551+
sstream << "Only networks with inputs batched by 0th dimension are supported. ";
552+
if (batch_dim_index_detected.has_value()) {
553+
sstream << "The batch has been detected on: " << batch_dim_index_detected.value()
554+
<< " dimension instead. ";
555+
} else {
556+
sstream << "The batch hasn't been detected at all. ";
557+
}
558+
sstream << "Please check input id: " << input_id << " by the name: " << input->get_friendly_name()
559+
<< ", layout: " << layout.to_string() << ", is_dynamic: " << shape.is_dynamic();
560+
logger.info("%s", sstream.str());
561+
return false;
562+
}
563+
}
564+
for (const auto& output : model->get_results()) {
565+
const auto& shape = output->get_output_partial_shape(0);
566+
ov::Layout layout = ov::layout::get_layout(output);
567+
568+
// Batching on plugin is working only when batching is found on 0th dimension
569+
if ((shape.size() && shape[0].get_max_length() > 1) ||
570+
(ov::layout::has_batch(layout) && ov::layout::batch_idx(layout) == 0)) {
571+
const auto& node = output->input_value(0);
572+
const auto& staticShape = shape.is_dynamic() ? shape.get_max_shape() : output->get_shape();
573+
batchedOutputs.insert(ov::Output<const ov::Node>(node.get_node(), node.get_index()));
574+
575+
if (shape.rank().is_dynamic()) {
576+
OPENVINO_THROW("Shapes with dynamic rank are not supported.");
577+
} else {
578+
sBatchSize.insert(staticShape[0]);
579+
}
580+
} else {
581+
logger.info("Only networks with outputs batched by 0th dimension are supported. Please check an output by "
582+
"the name: %s, layout: %s",
583+
output->get_friendly_name(),
584+
layout.to_string());
585+
return false;
586+
}
587+
}
588+
if (!batchedInputs.size() || !batchedOutputs.size()) {
589+
logger.info(
590+
"Only networks with inputs/outputs featuring batched dim are supported! Got inputs: %ld, outputs: %ld",
591+
batchedInputs.size(),
592+
batchedOutputs.size());
593+
return false;
594+
}
595+
596+
if (sBatchSize.size() != 1) {
597+
logger.info("Batching size shall have same value for all tensors! Got unique batch sizes number: %ld",
598+
sBatchSize.size());
599+
return false;
600+
}
601+
602+
auto node_info_printer = [&logger](const auto& ov_node, std::string nodeType) {
603+
logger.info("%s: %s has shape value: %s",
604+
nodeType,
605+
ov_node.get_any_name(),
606+
ov_node.get_partial_shape().to_string());
607+
};
608+
609+
for (const auto& ov_node : batchedInputs) {
610+
node_info_printer(ov_node, "Input");
611+
}
612+
for (const auto& ov_node : batchedOutputs) {
613+
node_info_printer(ov_node, "Output");
614+
}
615+
616+
return true;
617+
}
618+
519619
std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<const ov::Model>& model,
520620
const ov::AnyMap& properties) const {
521621
OV_ITT_SCOPED_TASK(itt::domains::NPUPlugin, "Plugin::compile_model");
622+
auto modelForCompilation = model->clone();
522623

523624
// Before going any further: if
524625
// ... 1 - NPUW mode is activated
@@ -560,21 +661,41 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
560661
auto device = _backend == nullptr ? nullptr : _backend->getDevice(localConfig.get<DEVICE_ID>());
561662
localConfig.update({{ov::intel_npu::platform.name(), platform}});
562663

563-
if (localConfig.isAvailable(ov::intel_npu::batch_mode.name()) &&
564-
!localConfig.has(ov::intel_npu::batch_mode.name())) {
664+
auto updateBatchMode = [&](ov::intel_npu::BatchMode mode) {
565665
std::stringstream strStream;
566-
strStream << ov::intel_npu::BatchMode::AUTO;
666+
strStream << mode;
667+
_logger.info("Setting batching mode to %s.", strStream.str());
567668
localConfig.update({{ov::intel_npu::batch_mode.name(), strStream.str()}});
669+
};
670+
671+
if (localConfig.isAvailable(ov::intel_npu::batch_mode.name()) &&
672+
!localConfig.has(ov::intel_npu::batch_mode.name())) {
673+
updateBatchMode(ov::intel_npu::BatchMode::AUTO);
568674
}
569675

570676
if (localConfig.isAvailable(ov::intel_npu::batch_mode.name()) && !model->get_variables().empty()) {
571677
if (localConfig.get<BATCH_MODE>() == ov::intel_npu::BatchMode::PLUGIN) {
572678
OPENVINO_THROW("This model contains states, thus it is not supported when handling batching on the plugin");
573679
}
574680

575-
std::stringstream strStream;
576-
strStream << ov::intel_npu::BatchMode::COMPILER;
577-
localConfig.update({{ov::intel_npu::batch_mode.name(), strStream.str()}});
681+
updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
682+
}
683+
684+
if (localConfig.isAvailable(ov::intel_npu::batch_mode.name())) {
685+
bool autoOrPluginBatch = localConfig.get<BATCH_MODE>() == ov::intel_npu::BatchMode::PLUGIN ||
686+
localConfig.get<BATCH_MODE>() == ov::intel_npu::BatchMode::AUTO;
687+
bool pluginBatchingIsSupported = validateModelBatch(modelForCompilation, _logger);
688+
if (autoOrPluginBatch && pluginBatchingIsSupported) {
689+
try {
690+
_logger.info("Attempting to handle batching on the plugin side.");
691+
ov::set_batch(modelForCompilation, 1);
692+
} catch (const std::exception& ex) {
693+
_logger.info("Couldn't reshape the model. Batching will be handed by compiler.", ex.what());
694+
}
695+
updateBatchMode(ov::intel_npu::BatchMode::COMPILER);
696+
} else {
697+
_logger.info("Unable to manage batching on the plugin side, so the compiler will take care of it.");
698+
}
578699
}
579700

580701
// Update stepping w/ information from driver, unless provided by user or we are off-device
@@ -625,10 +746,10 @@ std::shared_ptr<ov::ICompiledModel> Plugin::compile_model(const std::shared_ptr<
625746
_logger.debug("performing compile");
626747

627748
if (!localConfig.get<WEIGHTLESS_BLOB>()) {
628-
graph = compiler->compile(model->clone(), localConfig);
749+
graph = compiler->compile(modelForCompilation->clone(), localConfig);
629750
} else {
630751
check_weightless_cache_attribute_occurrence(model);
631-
graph = compiler->compileWS(model->clone(), localConfig);
752+
graph = compiler->compileWS(modelForCompilation->clone(), localConfig);
632753
}
633754
} catch (const std::exception& ex) {
634755
OPENVINO_THROW(ex.what());

0 commit comments

Comments (0)