diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 59979189eed0f..9c42bf34b5b0f 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1582,11 +1582,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForInputs(
   auto num_inputs = GetInputCount();
 
   std::vector<ConstMemoryInfo> mem_infos;
-  mem_infos.resize(num_inputs);
+  if (num_inputs > 0) {
+    mem_infos.resize(num_inputs);
 
-  ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
-                                                      reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
-                                                      num_inputs));
+    ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
+                                                        reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
+                                                        num_inputs));
+  }
 
   return mem_infos;
 }
@@ -1598,11 +1600,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForOutputs
   auto num_outputs = GetOutputCount();
 
   std::vector<ConstMemoryInfo> mem_infos;
-  mem_infos.resize(num_outputs);
+  if (num_outputs > 0) {
+    mem_infos.resize(num_outputs);
 
-  ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
-                                                       reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
-                                                       num_outputs));
+    ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
+                                                         reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
+                                                         num_outputs));
+  }
 
   return mem_infos;
 }
@@ -1631,12 +1635,12 @@ template <typename T>
 inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() const {
   auto num_inputs = GetInputCount();
   std::vector<ConstEpDevice> input_devices;
-  input_devices.resize(num_inputs);
-
-  ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
-                                                    reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
-                                                    num_inputs));
-
+  if (num_inputs > 0) {
+    input_devices.resize(num_inputs);
+    ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
+                                                      reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
+                                                      num_inputs));
+  }
   return input_devices;
 }
 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 17b7f9af372bc..c424bc4264b0d 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType
 
   for (const auto* def : def_list) {
     InlinedVector<SessionState::NodeInfo> node_info_vec;
+    Status status;
     if (type == SessionInputOutputType::kOutput) {
-      ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));
+      status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec);
     } else {
-      ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
+      status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
     }
 
-    // all entries are for the same OrtDevice so use the first one.
-    // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
-    // from the session state and use its OrtMemoryInfo.
-    auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
-    memory_info.push_back(&allocator->Info());
+    if (!status.IsOK()) {
+      if (type == SessionInputOutputType::kInput) {
+        return status;
+      }
+
+      // Check first if this output is produced by an input that directly
+      // propagates to an output with the same name.
+      status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
+      if (status.IsOK()) {
+        // all entries are for the same OrtDevice so use the first one.
+        // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+        // from the session state and use its OrtMemoryInfo.
+        auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+        memory_info.push_back(&allocator->Info());
+      } else {
+        // Check if this output is produced by a constant initializer.
+        // Pick the MemoryInfo from the initializer's OrtValue.
+        const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap();
+
+        OrtValueIndex ort_value_index;
+        status = ort_value_map.GetIdx(def->Name(), ort_value_index);
+        if (!status.IsOK()) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                                 "Failed to find node output or a constant initializer producing output: ",
+                                 def->Name(), ".");
+        }
+
+        const auto& idx_to_ort_value = session_state_->GetInitializedTensors();
+        auto it = idx_to_ort_value.find(ort_value_index);
+        if (it == idx_to_ort_value.end()) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                                 "Failed to find node output or a constant initializer producing output: ",
+                                 def->Name(), ".");
+        }
+        const auto& tensor = it->second.Get<Tensor>();
+        auto allocator = session_state_->GetAllocator(tensor.Location());
+        memory_info.push_back(&allocator->Info());
+      }
+    } else {
+      // all entries are for the same OrtDevice so use the first one.
+      // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
+      // from the session state and use its OrtMemoryInfo.
+      auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
+      memory_info.push_back(&allocator->Info());
+    }
   }
 
   return Status::OK();
@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector
     InlinedVector<SessionState::NodeInfo> node_info_vec;
     ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
-
-    // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map
-    // instead of doing a linear search each time.
-    const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType();
-    auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
-      return entry->ep_name == ep_name;
-    });
-
-    ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+    assert(!node_info_vec.empty());
+    // If we have an input that is not consumed by any node,
+    // including nodes in subgraphs, then we return nullptr for it.
+    const auto* p_node = node_info_vec.front().p_node;
+    if (p_node != nullptr) {
+      const auto ep_name = p_node->GetExecutionProviderType();
+      auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
+        return entry->ep_name == ep_name;
+      });
+      ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
+    } else {
+      ep_devices.push_back(nullptr);
+    }
   }
 
   return Status::OK();
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index b7a9da8e1b658..8c2928670934a 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders, CApiTestWithProvider,
                          ::testing::Values(0, 1, 2, 3, 4));
 
+TEST(CApiTest, TestInputPassThroughToOutput) {
+  const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx");
+  Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_EQ(1U, inputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_EQ(1U, inputs_epdevices.size());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(7U, outputs_meminfos.size());
+}
+
+TEST(CApiTest, TestDanglingInput) {
+  // Here we test a model where segment_ids is an input that is not consumed by any node.
+  // This kind of model is unlikely to be used in practice, but we want to make sure it works.
+  const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx");
+  Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_EQ(2U, inputs_meminfos.size());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(2U, outputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_EQ(2U, inputs_epdevices.size());
+  // One of the returned devices is null: since the input is not consumed,
+  // there is no device for it.
+  const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(),
+                                        [](const auto& device) { return device == nullptr; });
+  ASSERT_TRUE(null_present);
+}
+
 #if !defined(DISABLE_SPARSE_TENSORS)
 TEST(CApiTest, SparseOutputModel) {
   std::vector<int64_t> dense_shape{3, 3};
@@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) {
   std::vector<Ort::Value> ort_inputs;
   std::vector<const char*> input_names;
   const char* const output_names[] = {"values"};
+  // This model produces a sparse output from a constant sparse initializer.
   Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{});
+  auto inputs_meminfos = session.GetMemoryInfoForInputs();
+  ASSERT_TRUE(inputs_meminfos.empty());
+  auto outputs_meminfos = session.GetMemoryInfoForOutputs();
+  ASSERT_EQ(1U, outputs_meminfos.size());
+  auto inputs_epdevices = session.GetEpDeviceForInputs();
+  ASSERT_TRUE(inputs_epdevices.empty());
+
   auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
                                  output_names, 1);
   ASSERT_EQ(ort_outputs.size(), 1U);
diff --git a/onnxruntime/test/testdata/input_propagated_to_output.onnx b/onnxruntime/test/testdata/input_propagated_to_output.onnx
new file mode 100644
index 0000000000000..feeab10556cb0
Binary files /dev/null and b/onnxruntime/test/testdata/input_propagated_to_output.onnx differ
diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx
new file mode 100644
index 0000000000000..a83c21030ad67
Binary files /dev/null and b/onnxruntime/test/testdata/test_dangling_input_segment_ids.onnx differ
diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
new file mode 100644
index 0000000000000..c5eb8a600d6b5
--- /dev/null
+++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
@@ -0,0 +1,86 @@
+"""
+Run this script to recreate the original onnx model.
+Example usage:
+python test_dangling_input_segment_ids.py out_model_path.onnx
+"""
+
+import os
+import sys
+
+import numpy as np
+import onnx
+from onnx import TensorProto, helper, numpy_helper
+
+DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")
+
+
+def order_repeated_field(repeated_proto, key_name, order):
+    order = list(order)
+    repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))
+
+
+def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
+    node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
+    if doc_string == "":
+        node.doc_string = ""
+    order_repeated_field(node.attribute, "name", kwargs.keys())
+    return node
+
+
+def make_graph(*args, doc_string=None, **kwargs):
+    graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
+    if doc_string == "":
+        graph.doc_string = ""
+    return graph
+
+
+model = helper.make_model(
+    opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
+    ir_version=7,
+    graph=make_graph(
+        name="embed_layernorm_graph",
+        inputs=[
+            helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
+            helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
+            helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
+        ],
+        initializer=[
+            numpy_helper.from_array(
+                np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
+                name="word_embed",
+            ),
+            numpy_helper.from_array(
+                np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
+                name="pos_embed",
+            ),
+            numpy_helper.from_array(
+                np.array(
+                    [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
+                    dtype="float32",
+                ),
+                name="gamma",
+            ),
+            numpy_helper.from_array(
+                np.array(
+                    [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
+                ),
+                name="beta",
+            ),
+        ],
+        nodes=[
+            make_node(
+                "EmbedLayerNormalization",
+                inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"],
+                outputs=["layernorm_out", "mask_index_out"],
+                domain="com.microsoft",
+            )
+        ],
+    ),
+)
+
+if __name__ == "__main__" and len(sys.argv) == 2:
+    _, out_path = sys.argv
+    onnx.save(model, out_path)
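
As a usage illustration (a sketch, not part of the patch), the snippet below shows the caller-visible behavior the new tests exercise: GetMemoryInfoForInputs/GetMemoryInfoForOutputs now return empty vectors for sessions with no inputs/outputs instead of calling the C API with a zero count, and GetEpDeviceForInputs can return a nullptr entry for an input that no node consumes. The relative model path and the pointer-style nullptr comparison on the returned Ort::ConstEpDevice values are assumptions made for this example.

#include <cstddef>
#include <iostream>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;

  // Model added in this PR; any model with a graph input that no node consumes behaves the same way.
  // The path is assumed to be relative to the working directory.
  Ort::Session session(env, ORT_TSTR("testdata/test_dangling_input_segment_ids.onnx"), Ort::SessionOptions{});

  // These calls return empty vectors when the session has no inputs/outputs.
  auto input_mem_infos = session.GetMemoryInfoForInputs();
  auto output_mem_infos = session.GetMemoryInfoForOutputs();
  std::cout << input_mem_infos.size() << " input and " << output_mem_infos.size() << " output memory infos\n";

  // Entries can be nullptr for inputs that are not consumed by any node, so check before using them.
  auto input_ep_devices = session.GetEpDeviceForInputs();
  for (size_t i = 0; i < input_ep_devices.size(); ++i) {
    std::cout << "input " << i << (input_ep_devices[i] == nullptr ? ": no EP device\n" : ": EP device assigned\n");
  }

  return 0;
}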