Merged
32 changes: 18 additions & 14 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1582,11 +1582,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForInputs(

auto num_inputs = GetInputCount();
std::vector<ConstMemoryInfo> mem_infos;
mem_infos.resize(num_inputs);
if (num_inputs > 0) {
mem_infos.resize(num_inputs);

ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
num_inputs));
ThrowOnError(GetApi().SessionGetMemoryInfoForInputs(this->p_,
reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
num_inputs));
}

return mem_infos;
}
@@ -1598,11 +1600,13 @@ inline std::vector<ConstMemoryInfo> ConstSessionImpl<T>::GetMemoryInfoForOutputs

auto num_outputs = GetOutputCount();
std::vector<ConstMemoryInfo> mem_infos;
mem_infos.resize(num_outputs);
if (num_outputs > 0) {
mem_infos.resize(num_outputs);

ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
num_outputs));
ThrowOnError(GetApi().SessionGetMemoryInfoForOutputs(this->p_,
reinterpret_cast<const OrtMemoryInfo**>(mem_infos.data()),
num_outputs));
}
return mem_infos;
}

@@ -1631,12 +1635,12 @@ template <typename T>
inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() const {
auto num_inputs = GetInputCount();
std::vector<ConstEpDevice> input_devices;
input_devices.resize(num_inputs);

ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
num_inputs));

if (num_inputs > 0) {
input_devices.resize(num_inputs);
ThrowOnError(GetApi().SessionGetEpDeviceForInputs(this->p_,
reinterpret_cast<const OrtEpDevice**>(input_devices.data()),
num_inputs));
}
return input_devices;
}

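For context, a minimal sketch of how a caller might exercise these guarded accessors; the model path is a placeholder and the snippet simply assumes a default CPU session, it is not part of this PR:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  // "model.onnx" is a hypothetical path used only for illustration.
  Ort::Session session(env, ORT_TSTR("model.onnx"), Ort::SessionOptions{});

  // With the change above, a model with zero inputs or outputs now yields empty
  // vectors here instead of passing a null data pointer to the C API.
  auto input_infos = session.GetMemoryInfoForInputs();
  auto output_infos = session.GetMemoryInfoForOutputs();
  auto input_devices = session.GetEpDeviceForInputs();

  return input_infos.size() == session.GetInputCount() ? 0 : 1;
}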
77 changes: 61 additions & 16 deletions onnxruntime/core/session/inference_session.cc
@@ -3383,17 +3383,58 @@ common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType

for (const auto* def : def_list) {
InlinedVector<SessionState::NodeInfo> node_info_vec;
Status status;
if (type == SessionInputOutputType::kOutput) {
ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));
status = session_state_->GetOutputNodeInfo(def->Name(), node_info_vec);
} else {
ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));
status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
}

// all entries are for the same OrtDevice so use the first one.
// we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
// from the session state and use its OrtMemoryInfo.
auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
memory_info.push_back(&allocator->Info());
if (!status.IsOK()) {
if (type == SessionInputOutputType::kInput) {
return status;
}

// Check first whether this output is produced by an input that directly
// propagates to an output with the same name.
status = session_state_->GetInputNodeInfo(def->Name(), node_info_vec);
if (status.IsOK()) {
// all entries are for the same OrtDevice so use the first one.
// we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
// from the session state and use its OrtMemoryInfo.
auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
memory_info.push_back(&allocator->Info());
} else {
// Check if this output is produced by a constant initializer
// Pick the MemoryInfo from the initializer's OrtValue
const auto& ort_value_map = session_state_->GetOrtValueNameIdxMap();

OrtValueIndex ort_value_index;
status = ort_value_map.GetIdx(def->Name(), ort_value_index);
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Failed to find node output or a constant initializer producing output: ",
def->Name(), ".");
}

const auto& idx_to_ort_value = session_state_->GetInitializedTensors();
auto it = idx_to_ort_value.find(ort_value_index);
if (it == idx_to_ort_value.end()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"Failed to find node output or a constant initializer producing output: ",
def->Name(), ".");
}
const auto& tensor = it->second.Get<Tensor>();
auto allocator = session_state_->GetAllocator(tensor.Location());
memory_info.push_back(&allocator->Info());
}
} else {
// all entries are for the same OrtDevice so use the first one.
// we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice
// from the session state and use its OrtMemoryInfo.
auto allocator = session_state_->GetAllocator(*node_info_vec.front().device);
memory_info.push_back(&allocator->Info());
}
}

return Status::OK();
@@ -3422,15 +3463,19 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector<const OrtEpD
for (const auto* def : def_list) {
InlinedVector<SessionState::NodeInfo> node_info_vec;
ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec));

// if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map
// instead of doing a linear search each time.
const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType();
auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
return entry->ep_name == ep_name;
});

ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
assert(!node_info_vec.empty());
// If we have an input that is not consumed by any node,
// including nodes in subgraphs, then we return nullptr.
const auto* p_node = node_info_vec.front().p_node;
if (p_node != nullptr) {
const auto ep_name = p_node->GetExecutionProviderType();
auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
return entry->ep_name == ep_name;
});
ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
} else {
ep_devices.push_back(nullptr);
}
}

return Status::OK();
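Because an input that no node consumes now maps to a nullptr entry, callers iterating the result of GetEpDeviceForInputs should tolerate null devices. A small sketch, assuming an already-created Ort::Session named session:

// Sketch only: `session` is assumed to exist; it is not defined in this PR.
auto devices = session.GetEpDeviceForInputs();
size_t unassigned = 0;
for (const auto& device : devices) {
  if (device == nullptr) {
    // This input is not consumed by any node (including nodes in subgraphs),
    // so no execution provider device is reported for it.
    ++unassigned;
  }
}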
37 changes: 37 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
@@ -494,6 +494,35 @@ INSTANTIATE_TEST_SUITE_P(CApiTestWithProviders,
CApiTestWithProvider,
::testing::Values(0, 1, 2, 3, 4));

TEST(CApiTest, TestInputPassThroughToOutput) {
const ORTCHAR_T* model_uri = TSTR("testdata/input_propagated_to_output.onnx");
Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
auto inputs_meminfos = session.GetMemoryInfoForInputs();
ASSERT_EQ(1U, inputs_meminfos.size());
auto inputs_epdevices = session.GetEpDeviceForInputs();
ASSERT_EQ(1U, inputs_epdevices.size());
auto outputs_meminfos = session.GetMemoryInfoForOutputs();
ASSERT_EQ(7U, outputs_meminfos.size());
}

TEST(CApiTest, TestDanglingInput) {
// Here we test an issue with segment_ids, an input that is not consumed by anything.
// This kind of model is unlikely to be used in practice, but we want to make sure it works.
const ORTCHAR_T* model_uri = TSTR("testdata/test_dangling_input_segment_ids.onnx");
Ort::Session session(*ort_env, model_uri, Ort::SessionOptions{});
auto inputs_meminfos = session.GetMemoryInfoForInputs();
ASSERT_EQ(2U, inputs_meminfos.size());
auto outputs_meminfos = session.GetMemoryInfoForOutputs();
ASSERT_EQ(2U, outputs_meminfos.size());
auto inputs_epdevices = session.GetEpDeviceForInputs();
ASSERT_EQ(2U, inputs_epdevices.size());
// One of the returned devices is null: since the input is not consumed,
// there is no device for it.
const bool null_present = std::any_of(inputs_epdevices.begin(), inputs_epdevices.end(),
[](const auto& device) { return device == nullptr; });
ASSERT_TRUE(null_present);
}

#if !defined(DISABLE_SPARSE_TENSORS)
TEST(CApiTest, SparseOutputModel) {
std::vector<int64_t> dense_shape{3, 3};
@@ -505,7 +534,15 @@ TEST(CApiTest, SparseOutputModel) {
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_names;
const char* const output_names[] = {"values"};
// This model produces a sparse output from a constant sparse initializer
Ort::Session session(*ort_env, SPARSE_OUTPUT_MODEL_URI, Ort::SessionOptions{});
auto inputs_meminfos = session.GetMemoryInfoForInputs();
ASSERT_TRUE(inputs_meminfos.empty());
auto outputs_meminfos = session.GetMemoryInfoForOutputs();
ASSERT_EQ(1U, outputs_meminfos.size());
auto inputs_epdevices = session.GetEpDeviceForInputs();
ASSERT_TRUE(inputs_epdevices.empty());

auto ort_outputs = session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
output_names, 1);
ASSERT_EQ(ort_outputs.size(), 1U);
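As a usage note, the memory infos returned for outputs can be inspected with the pre-existing Ort::MemoryInfo accessors, for example to decide whether results need copying back to CPU. A rough sketch, again assuming an existing session (GetAllocatorName is the long-standing accessor; "Cpu" is the default CPU allocator name):

// Sketch only: `session` is assumed to exist; it is not defined in this PR.
auto output_infos = session.GetMemoryInfoForOutputs();
for (const auto& info : output_infos) {
  const bool on_cpu = info.GetAllocatorName() == "Cpu";
  if (!on_cpu) {
    // e.g. plan a device-to-host copy for this output.
  }
}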
Binary file not shown.
Binary file not shown.
86 changes: 86 additions & 0 deletions onnxruntime/test/testdata/test_dangling_input_segment_ids.py
@@ -0,0 +1,86 @@
"""
Run this script to recreate the original onnx model.
Example usage:
python test_dangling_input_segment_ids.py out_model_path.onnx
"""

import os
import sys

import numpy as np
import onnx

Check notice — Code scanning / CodeQL: Module is imported with 'import' and 'import from' (Note, test)

Module 'onnx' is imported with both 'import' and 'import from'.
Module 'onnxruntime.test.onnx' is imported with both 'import' and 'import from'.

Copilot Autofix


To fix the problem, remove the line from onnx import TensorProto, helper, numpy_helper (line 12) and replace all occurrences of TensorProto, helper, and numpy_helper with onnx.TensorProto, onnx.helper, and onnx.numpy_helper, respectively, throughout the file. This keeps the code clear and avoids confusion; no APIs change. The only file to be changed is onnxruntime/test/testdata/test_dangling_input_segment_ids.py. No new imports or definitions are needed, just refactoring for proper qualification.

Suggested changeset 1
onnxruntime/test/testdata/test_dangling_input_segment_ids.py

Autofix patch

Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
--- a/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
+++ b/onnxruntime/test/testdata/test_dangling_input_segment_ids.py
@@ -9,7 +9,6 @@
 
 import numpy as np
 import onnx
-from onnx import TensorProto, helper, numpy_helper
 
 DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")
 
@@ -20,7 +19,7 @@
 
 
 def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
-    node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
+    node = onnx.helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
     if doc_string == "":
         node.doc_string = ""
     order_repeated_field(node.attribute, "name", kwargs.keys())
@@ -28,42 +27,42 @@
 
 
 def make_graph(*args, doc_string=None, **kwargs):
-    graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
+    graph = onnx.helper.make_graph(*args, doc_string=doc_string, **kwargs)
     if doc_string == "":
         graph.doc_string = ""
     return graph
 
 
-model = helper.make_model(
-    opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
+model = onnx.helper.make_model(
+    opset_imports=[onnx.helper.make_operatorsetid("", 14), onnx.helper.make_operatorsetid("com.microsoft", 1)],
     ir_version=7,
     graph=make_graph(
         name="embed_layernorm_graph",
         inputs=[
-            helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
-            helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
+            onnx.helper.make_tensor_value_info("input_ids", onnx.TensorProto.INT32, shape=[1, 4]),
+            onnx.helper.make_tensor_value_info("segment_ids", onnx.TensorProto.INT32, shape=[1, 4]),
         ],
         outputs=[
-            helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
-            helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
+            onnx.helper.make_tensor_value_info("layernorm_out", onnx.TensorProto.FLOAT, shape=[1, 4, 4]),
+            onnx.helper.make_tensor_value_info("mask_index_out", onnx.TensorProto.INT32, shape=[1]),
         ],
         initializer=[
-            numpy_helper.from_array(
+            onnx.numpy_helper.from_array(
                 np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
                 name="word_embed",
             ),
-            numpy_helper.from_array(
+            onnx.numpy_helper.from_array(
                 np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
                 name="pos_embed",
             ),
-            numpy_helper.from_array(
+            onnx.numpy_helper.from_array(
                 np.array(
                     [0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
                     dtype="float32",
                 ),
                 name="gamma",
             ),
-            numpy_helper.from_array(
+            onnx.numpy_helper.from_array(
                 np.array(
                     [0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
                 ),
EOF
from onnx import TensorProto, helper, numpy_helper

DATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_dangling_input_segment_ids")


def order_repeated_field(repeated_proto, key_name, order):
order = list(order)
repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name)))


def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
if doc_string == "":
node.doc_string = ""
order_repeated_field(node.attribute, "name", kwargs.keys())
return node


def make_graph(*args, doc_string=None, **kwargs):
graph = helper.make_graph(*args, doc_string=doc_string, **kwargs)
if doc_string == "":
graph.doc_string = ""
return graph


model = helper.make_model(
opset_imports=[helper.make_operatorsetid("", 14), helper.make_operatorsetid("com.microsoft", 1)],
ir_version=7,
graph=make_graph(
name="embed_layernorm_graph",
inputs=[
helper.make_tensor_value_info("input_ids", TensorProto.INT32, shape=[1, 4]),
helper.make_tensor_value_info("segment_ids", TensorProto.INT32, shape=[1, 4]),
],
outputs=[
helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, shape=[1, 4, 4]),
helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, shape=[1]),
],
initializer=[
numpy_helper.from_array(
np.load(os.path.join(DATA_DIR, "const0_word_embed.npy")).astype("float32").reshape([32, 4]),
name="word_embed",
),
numpy_helper.from_array(
np.load(os.path.join(DATA_DIR, "const1_pos_embed.npy")).astype("float32").reshape([16, 4]),
name="pos_embed",
),
numpy_helper.from_array(
np.array(
[0.6185135841369629, 0.010364261455833912, 0.5386272668838501, 0.0030179566238075495],
dtype="float32",
),
name="gamma",
),
numpy_helper.from_array(
np.array(
[0.9511938095092773, 0.9054020047187805, 0.7959669232368469, 0.9152743220329285], dtype="float32"
),
name="beta",
),
],
nodes=[
make_node(
"EmbedLayerNormalization",
inputs=["input_ids", "", "word_embed", "pos_embed", "", "gamma", "beta"],
outputs=["layernorm_out", "mask_index_out"],
domain="com.microsoft",
)
],
),
)

if __name__ == "__main__" and len(sys.argv) == 2:
_, out_path = sys.argv
onnx.save(model, out_path)