Skip to content

Commit 41c0519

Browse files
committed
correct support for shm to tensor creation
1 parent 050116e commit 41c0519

File tree

2 files changed

+31
-10
lines changed

2 files changed

+31
-10
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.bioimage.modelrunner.utils.CommonUtils;
2626

2727
import java.io.IOException;
28+
import java.nio.ByteBuffer;
2829
import java.util.Arrays;
2930

3031
import org.bytedeco.pytorch.Tensor;
@@ -80,7 +81,8 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
8081
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
8182
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
8283
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
83-
tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
84+
ByteBuffer buff = shma.getDataBufferNoHeader();
85+
buff = tensor.asByteBuffer();
8486
if (PlatformDetection.isWindows()) shma.close();
8587
}
8688

@@ -91,7 +93,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
9193
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9294
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
9395
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
94-
tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
96+
ByteBuffer buff = shma.getDataBufferNoHeader();
97+
buff = tensor.asByteBuffer();
9598
if (PlatformDetection.isWindows()) shma.close();
9699
}
97100

@@ -102,7 +105,8 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
102105
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
103106
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
104107
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
105-
tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
108+
ByteBuffer buff = shma.getDataBufferNoHeader();
109+
buff = tensor.asByteBuffer();
106110
if (PlatformDetection.isWindows()) shma.close();
107111
}
108112

@@ -113,7 +117,8 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
113117
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
114118
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
115119
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
116-
tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
120+
ByteBuffer buff = shma.getDataBufferNoHeader();
121+
buff = tensor.asByteBuffer();
117122
if (PlatformDetection.isWindows()) shma.close();
118123
}
119124

@@ -124,7 +129,8 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
124129
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
125130
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
126131
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
127-
tensor.data_ptr_byte().get(shma.getDataBufferNoHeader().array());
132+
ByteBuffer buff = shma.getDataBufferNoHeader();
133+
buff = tensor.asByteBuffer();
128134
if (PlatformDetection.isWindows()) shma.close();
129135
}
130136
}

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ private static Tensor buildByte(SharedMemoryArray shmArray)
8585
if (!shmArray.isNumpyFormat())
8686
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
8787
ByteBuffer buff = shmArray.getDataBufferNoHeader();
88-
Tensor ndarray = Tensor.create(buff.array(), ogShape);
88+
byte[] flat = new byte[buff.capacity()];
89+
buff.get(flat);
90+
buff.rewind();
91+
Tensor ndarray = Tensor.create(flat, ogShape);
8992
return ndarray;
9093
}
9194

@@ -99,7 +102,10 @@ private static Tensor buildInt(SharedMemoryArray shmaArray)
99102
if (!shmaArray.isNumpyFormat())
100103
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
101104
ByteBuffer buff = shmaArray.getDataBufferNoHeader();
102-
Tensor ndarray = Tensor.create(buff.asIntBuffer().array(), ogShape);
105+
int[] flat = new int[buff.capacity() / 4];
106+
buff.asIntBuffer().get(flat);
107+
buff.rewind();
108+
Tensor ndarray = Tensor.create(flat, ogShape);
103109
return ndarray;
104110
}
105111

@@ -113,7 +119,10 @@ private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray shmArray)
113119
if (!shmArray.isNumpyFormat())
114120
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
115121
ByteBuffer buff = shmArray.getDataBufferNoHeader();
116-
Tensor ndarray = Tensor.create(buff.asLongBuffer().array(), ogShape);
122+
long[] flat = new long[buff.capacity() / 8];
123+
buff.asLongBuffer().get(flat);
124+
buff.rewind();
125+
Tensor ndarray = Tensor.create(flat, ogShape);
117126
return ndarray;
118127
}
119128

@@ -127,7 +136,10 @@ private static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray shmArray
127136
if (!shmArray.isNumpyFormat())
128137
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
129138
ByteBuffer buff = shmArray.getDataBufferNoHeader();
130-
Tensor ndarray = Tensor.create(buff.asFloatBuffer().array(), ogShape);
139+
float[] flat = new float[buff.capacity() / 4];
140+
buff.asFloatBuffer().get(flat);
141+
buff.rewind();
142+
Tensor ndarray = Tensor.create(flat, ogShape);
131143
return ndarray;
132144
}
133145

@@ -141,7 +153,10 @@ private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray shmArra
141153
if (!shmArray.isNumpyFormat())
142154
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
143155
ByteBuffer buff = shmArray.getDataBufferNoHeader();
144-
Tensor ndarray = Tensor.create(buff.asDoubleBuffer().array(), ogShape);
156+
double[] flat = new double[buff.capacity() / 8];
157+
buff.asDoubleBuffer().get(flat);
158+
buff.rewind();
159+
Tensor ndarray = Tensor.create(flat, ogShape);
145160
return ndarray;
146161
}
147162
}

0 commit comments

Comments
 (0)