Skip to content

Commit 9d293cc

Browse files
committed
DLPack memoryview support
1 parent b74debb commit 9d293cc

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

docs/api_extra.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,10 @@ convert into an equivalent representation in one of the following frameworks:
11041104

11051105
.. cpp:class:: cupy
11061106

1107+
.. cpp:class:: memview
1108+
1109+
Builtin Python ``memoryview`` for CPU-resident data.
1110+
11071111
Eigen convenience type aliases
11081112
------------------------------
11091113

include/nanobind/ndarray.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ NB_FRAMEWORK(pytorch, 2, "torch.Tensor");
8585
NB_FRAMEWORK(tensorflow, 3, "tensorflow.python.framework.ops.EagerTensor");
8686
NB_FRAMEWORK(jax, 4, "jaxlib.xla_extension.DeviceArray");
8787
NB_FRAMEWORK(cupy, 5, "cupy.ndarray");
88+
NB_FRAMEWORK(memview, 6, "memoryview");
8889

8990
NAMESPACE_BEGIN(device)
9091
NB_DEVICE(none, 0); NB_DEVICE(cpu, 1); NB_DEVICE(cuda, 2);

src/nb_ndarray.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
765765
object o;
766766
if (copy && framework == no_framework::value && th->self) {
767767
o = borrow(th->self);
768-
} else if (framework == numpy::value || framework == jax::value) {
768+
} else if (framework == numpy::value || framework == jax::value || framework == memview::value) {
769769
nb_ndarray *h = PyObject_New(nb_ndarray, nd_ndarray_tp());
770770
if (!h)
771771
return nullptr;
@@ -784,6 +784,8 @@ PyObject *ndarray_export(ndarray_handle *th, int framework,
784784
.attr("array")(o, arg("copy") = copy)
785785
.release()
786786
.ptr();
787+
} else if (framework == memview::value) {
788+
return PyMemoryView_FromObject(o.ptr());
787789
} else {
788790
const char *pkg_name;
789791
switch (framework) {

0 commit comments

Comments
 (0)