|
22 | 22 | #include "core/framework/data_transfer_utils.h" |
23 | 23 | #include "core/framework/data_types_internal.h" |
24 | 24 | #include "core/framework/error_code_helper.h" |
| 25 | +#include "core/framework/plugin_ep_stream.h" |
25 | 26 | #include "core/framework/provider_options_utils.h" |
26 | 27 | #include "core/framework/random_seed.h" |
27 | 28 | #include "core/framework/sparse_tensor.h" |
@@ -1584,6 +1585,18 @@ void addGlobalMethods(py::module& m) { |
1584 | 1585 | }, |
1585 | 1586 | R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); |
1586 | 1587 |
|
| 1588 | + m.def( |
| 1589 | + "copy_tensors", |
| 1590 | + [](const std::vector<const OrtValue*>& src, const std::vector<OrtValue*>& dest, py::object& py_arg) { |
| 1591 | + const OrtEnv* ort_env = GetOrtEnv(); |
| 1592 | + OrtSyncStream* stream = nullptr; |
| 1593 | + if (!py_arg.is_none()) { |
| 1594 | + stream = py_arg.cast<OrtSyncStream*>(); |
| 1595 | + } |
| 1596 | + Ort::ThrowOnError(Ort::GetApi().CopyTensors(ort_env, src.data(), dest.data(), stream, src.size())); |
| 1597 | + }, |
| 1598 | + R"pbdoc("Copy tensors from sources to destinations using specified stream handle (or None))pbdoc"); |
| 1599 | + |
1587 | 1600 | #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) |
1588 | 1601 | m.def( |
1589 | 1602 | "get_available_openvino_device_ids", []() -> std::vector<std::string> { |
@@ -1785,6 +1798,16 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra |
1785 | 1798 | .value("CPU", OrtMemTypeCPU) |
1786 | 1799 | .value("DEFAULT", OrtMemTypeDefault); |
1787 | 1800 |
|
| 1801 | + py::enum_<OrtMemoryInfoDeviceType>(m, "OrtMemoryInfoDeviceType") |
| 1802 | + .value("CPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_CPU) |
| 1803 | + .value("GPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU) |
| 1804 | + .value("NPU", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_NPU) |
| 1805 | + .value("FPGA", OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_FPGA); |
| 1806 | + |
| 1807 | + py::enum_<OrtDeviceMemoryType>(m, "OrtDeviceMemoryType") |
| 1808 | + .value("DEFAULT", OrtDeviceMemoryType_DEFAULT) |
| 1809 | + .value("HOST_ACCESSIBLE", OrtDeviceMemoryType_HOST_ACCESSIBLE); |
| 1810 | + |
1788 | 1811 | py::class_<OrtDevice> device(m, "OrtDevice", R"pbdoc(ONNXRuntime device information.)pbdoc"); |
1789 | 1812 | device.def(py::init<OrtDevice::DeviceType, OrtDevice::MemoryType, OrtDevice::VendorId, OrtDevice::DeviceId>()) |
1790 | 1813 | .def(py::init([](OrtDevice::DeviceType type, |
@@ -1813,6 +1836,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra |
1813 | 1836 | .def("device_id", &OrtDevice::Id, R"pbdoc(Device Id.)pbdoc") |
1814 | 1837 | .def("device_type", &OrtDevice::Type, R"pbdoc(Device Type.)pbdoc") |
1815 | 1838 | .def("vendor_id", &OrtDevice::Vendor, R"pbdoc(Vendor Id.)pbdoc") |
| 1839 | + .def("mem_type", &OrtDevice::MemType, R"pbdoc(Device Memory Type.)pbdoc") |
1816 | 1840 | // generic device types that are typically used with a vendor id. |
1817 | 1841 | .def_static("cpu", []() { return OrtDevice::CPU; }) |
1818 | 1842 | .def_static("gpu", []() { return OrtDevice::GPU; }) |
@@ -1863,36 +1887,55 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra |
1863 | 1887 | }, |
1864 | 1888 | R"pbdoc(Hardware device's metadata as string key/value pairs.)pbdoc"); |
1865 | 1889 |
|
| 1890 | + py::class_<OrtSyncStream> py_sync_stream(m, "OrtSyncStream", |
| 1891 | + R"pbdoc(Represents a synchronization stream for model inference.)pbdoc"); |
| 1892 | + |
1866 | 1893 | py::class_<OrtEpDevice> py_ep_device(m, "OrtEpDevice", |
1867 | 1894 | R"pbdoc(Represents a hardware device that an execution provider supports |
1868 | 1895 | for model inference.)pbdoc"); |
1869 | 1896 | py_ep_device.def_property_readonly( |
1870 | 1897 | "ep_name", |
1871 | | - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, |
| 1898 | + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_name; }, |
1872 | 1899 | R"pbdoc(The execution provider's name.)pbdoc") |
1873 | 1900 | .def_property_readonly( |
1874 | 1901 | "ep_vendor", |
1875 | | - [](OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, |
| 1902 | + [](const OrtEpDevice* ep_device) -> std::string { return ep_device->ep_vendor; }, |
1876 | 1903 | R"pbdoc(The execution provider's vendor name.)pbdoc") |
1877 | 1904 | .def_property_readonly( |
1878 | 1905 | "ep_metadata", |
1879 | | - [](OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
| 1906 | + [](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
1880 | 1907 | return ep_device->ep_metadata.Entries(); |
1881 | 1908 | }, |
1882 | 1909 | R"pbdoc(The execution provider's additional metadata for the OrtHardwareDevice.)pbdoc") |
1883 | 1910 | .def_property_readonly( |
1884 | 1911 | "ep_options", |
1885 | | - [](OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
| 1912 | + [](const OrtEpDevice* ep_device) -> std::map<std::string, std::string> { |
1886 | 1913 | return ep_device->ep_options.Entries(); |
1887 | 1914 | }, |
1888 | 1915 | R"pbdoc(The execution provider's options used to configure the provider to use the OrtHardwareDevice.)pbdoc") |
1889 | 1916 | .def_property_readonly( |
1890 | 1917 | "device", |
1891 | | - [](OrtEpDevice* ep_device) -> const OrtHardwareDevice& { |
| 1918 | + [](const OrtEpDevice* ep_device) -> const OrtHardwareDevice& { |
1892 | 1919 | return *ep_device->device; |
1893 | 1920 | }, |
1894 | 1921 | R"pbdoc(The OrtHardwareDevice instance for the OrtEpDevice.)pbdoc", |
1895 | | - py::return_value_policy::reference_internal); |
| 1922 | + py::return_value_policy::reference_internal) |
| 1923 | + .def( |
| 1924 | + "memory_info", |
| 1925 | + [](const OrtEpDevice* ep_device, OrtDeviceMemoryType memory_type) -> const OrtMemoryInfo* { |
| 1926 | + Ort::ConstEpDevice ep_dev(ep_device); |
| 1927 | + return static_cast<const OrtMemoryInfo*>(ep_dev.GetMemoryInfo(memory_type)); |
| 1928 | + }, |
| 1929 | + R"pbdoc(The OrtMemoryInfo instance for the OrtEpDevice specific to the device memory type.)pbdoc", |
| 1930 | + py::return_value_policy::reference_internal) |
| 1931 | + .def( |
| 1932 | + "create_sync_stream", |
| 1933 | + [](const OrtEpDevice* ep_device) -> std::unique_ptr<OrtSyncStream> { |
| 1934 | + Ort::ConstEpDevice ep_dev(ep_device); |
| 1935 | + Ort::SyncStream stream = ep_dev.CreateSyncStream(); |
| 1936 | + return std::unique_ptr<OrtSyncStream>(stream.release()); |
| 1937 | + }, |
| 1938 | + R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); |
1896 | 1939 |
|
1897 | 1940 | py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg"); |
1898 | 1941 | // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. |
@@ -1938,25 +1981,28 @@ for model inference.)pbdoc"); |
1938 | 1981 | .def_readwrite("max_power_of_two_extend_bytes", &OrtArenaCfg::max_power_of_two_extend_bytes); |
1939 | 1982 |
|
1940 | 1983 | py::class_<OrtMemoryInfo> ort_memory_info_binding(m, "OrtMemoryInfo"); |
1941 | | - ort_memory_info_binding.def(py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { |
1942 | | - if (strcmp(name, onnxruntime::CPU) == 0) { |
1943 | | - return std::make_unique<OrtMemoryInfo>(onnxruntime::CPU, type, OrtDevice(), mem_type); |
1944 | | - } else if (strcmp(name, onnxruntime::CUDA) == 0) { |
1945 | | - return std::make_unique<OrtMemoryInfo>( |
1946 | | - onnxruntime::CUDA, type, |
1947 | | - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, |
1948 | | - static_cast<OrtDevice::DeviceId>(id)), |
1949 | | - mem_type); |
1950 | | - } else if (strcmp(name, onnxruntime::CUDA_PINNED) == 0) { |
1951 | | - return std::make_unique<OrtMemoryInfo>( |
1952 | | - onnxruntime::CUDA_PINNED, type, |
1953 | | - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE, OrtDevice::VendorIds::NVIDIA, |
1954 | | - static_cast<OrtDevice::DeviceId>(id)), |
1955 | | - mem_type); |
1956 | | - } else { |
1957 | | - throw std::runtime_error("Specified device is not supported."); |
1958 | | - } |
1959 | | - })); |
| 1984 | + ort_memory_info_binding.def( |
| 1985 | + py::init([](const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) { |
| 1986 | + Ort::MemoryInfo result(name, type, id, mem_type); |
| 1987 | + return std::unique_ptr<OrtMemoryInfo>(result.release()); |
| 1988 | + })) |
| 1989 | + .def_static( |
| 1990 | + "create_v2", |
| 1991 | + [](const char* name, OrtMemoryInfoDeviceType device_type, uint32_t vendor_id, |
| 1992 | + int32_t device_id, OrtDeviceMemoryType device_mem_type, size_t alignment, OrtAllocatorType type) { |
| 1993 | + Ort::MemoryInfo result(name, device_type, vendor_id, device_id, device_mem_type, alignment, type); |
| 1994 | + return std::unique_ptr<OrtMemoryInfo>(result.release()); |
| 1995 | + }, |
| 1996 | + R"pbdoc(Create an OrtMemoryInfo instance using CreateMemoryInfo_V2())pbdoc") |
| 1997 | + .def_property_readonly("name", [](const OrtMemoryInfo* mem_info) -> std::string { return mem_info->name; }, R"pbdoc(Arbitrary name supplied by the user)pbdoc") |
| 1998 | + .def_property_readonly("device_id", [](const OrtMemoryInfo* mem_info) -> int { return mem_info->device.Id(); }, R"pbdoc(Device Id.)pbdoc") |
| 1999 | + .def_property_readonly("mem_type", [](const OrtMemoryInfo* mem_info) -> OrtMemType { return mem_info->mem_type; }, R"pbdoc(OrtMemoryInfo memory type.)pbdoc") |
| 2000 | + .def_property_readonly("allocator_type", [](const OrtMemoryInfo* mem_info) -> OrtAllocatorType { return mem_info->alloc_type; }, R"pbdoc(Allocator type)pbdoc") |
| 2001 | + .def_property_readonly("device_mem_type", [](const OrtMemoryInfo* mem_info) -> OrtDeviceMemoryType { |
| 2002 | + auto mem_type = mem_info->device.MemType(); |
| 2003 | + return (mem_type == OrtDevice::MemType::DEFAULT) ? |
| 2004 | + OrtDeviceMemoryType_DEFAULT: OrtDeviceMemoryType_HOST_ACCESSIBLE ; }, R"pbdoc(Device memory type (Device or Host accessible).)pbdoc") |
| 2005 | + .def_property_readonly("device_vendor_id", [](const OrtMemoryInfo* mem_info) -> uint32_t { return mem_info->device.Vendor(); }); |
1960 | 2006 |
|
1961 | 2007 | py::class_<PySessionOptions> |
1962 | 2008 | sess(m, "SessionOptions", R"pbdoc(Configuration information for a session.)pbdoc"); |
@@ -2653,6 +2699,33 @@ including arg name, arg type (contains both type and shape).)pbdoc") |
2653 | 2699 | auto res = sess->GetSessionHandle()->GetModelMetadata(); |
2654 | 2700 | OrtPybindThrowIfError(res.first); |
2655 | 2701 | return *(res.second); }, py::return_value_policy::reference_internal) |
| 2702 | + .def_property_readonly("input_meminfos", [](const PyInferenceSession* sess) -> py::list { |
| 2703 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2704 | + auto inputs_mem_info = session.GetMemoryInfoForInputs(); |
| 2705 | + py::list result; |
| 2706 | + for (const auto& info : inputs_mem_info) { |
| 2707 | + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); |
| 2708 | + result.append(py::cast(p_info, py::return_value_policy::reference)); |
| 2709 | + } |
| 2710 | + return result; }) |
| 2711 | + .def_property_readonly("output_meminfos", [](const PyInferenceSession* sess) -> py::list { |
| 2712 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2713 | + auto outputs_mem_info = session.GetMemoryInfoForOutputs(); |
| 2714 | + py::list result; |
| 2715 | + for (const auto& info : outputs_mem_info) { |
| 2716 | + const auto* p_info = static_cast<const OrtMemoryInfo*>(info); |
| 2717 | + result.append(py::cast(p_info, py::return_value_policy::reference)); |
| 2718 | + } |
| 2719 | + return result; }) |
| 2720 | + .def_property_readonly("input_epdevices", [](const PyInferenceSession* sess) -> py::list { |
| 2721 | + Ort::ConstSession session(reinterpret_cast<const OrtSession*>(sess->GetSessionHandle())); |
| 2722 | + auto ep_devices = session.GetEpDeviceForInputs(); |
| 2723 | + py::list result; |
| 2724 | + for (const auto& device : ep_devices) { |
| 2725 | + const auto* p_device = static_cast<const OrtEpDevice*>(device); |
| 2726 | + result.append(py::cast(p_device, py::return_value_policy::reference)); |
| 2727 | + } |
| 2728 | + return result; }) |
2656 | 2729 | .def("run_with_iobinding", [](PyInferenceSession* sess, SessionIOBinding& io_binding, RunOptions* run_options = nullptr) -> void { |
2657 | 2730 |
|
2658 | 2731 | Status status; |
|
0 commit comments