Skip to content

Commit 362f0cc

Browse files
committed
actually copy the bytes from pytorch tensor into the array
1 parent bcfee59 commit 362f0cc

File tree

1 file changed

+5
-11
lines changed

1 file changed

+5
-11
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import io.bioimage.modelrunner.utils.CommonUtils;
2626

2727
import java.io.IOException;
28-
import java.nio.ByteBuffer;
2928
import java.util.Arrays;
3029

3130
import org.bytedeco.pytorch.Tensor;
@@ -81,8 +80,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
8180
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
8281
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
8382
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
84-
ByteBuffer buff = shma.getDataBufferNoHeader();
85-
buff = tensor.asByteBuffer();
83+
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
8684
if (PlatformDetection.isWindows()) shma.close();
8785
}
8886

@@ -93,8 +91,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
9391
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9492
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
9593
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
96-
ByteBuffer buff = shma.getDataBufferNoHeader();
97-
buff = tensor.asByteBuffer();
94+
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
9895
if (PlatformDetection.isWindows()) shma.close();
9996
}
10097

@@ -105,8 +102,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
105102
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
106103
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
107104
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
108-
ByteBuffer buff = shma.getDataBufferNoHeader();
109-
buff = tensor.asByteBuffer();
105+
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
110106
if (PlatformDetection.isWindows()) shma.close();
111107
}
112108

@@ -117,8 +113,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
117113
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
118114
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
119115
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
120-
ByteBuffer buff = shma.getDataBufferNoHeader();
121-
buff = tensor.asByteBuffer();
116+
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
122117
if (PlatformDetection.isWindows()) shma.close();
123118
}
124119

@@ -129,8 +124,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
129124
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
130125
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
131126
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
132-
ByteBuffer buff = shma.getDataBufferNoHeader();
133-
buff = tensor.asByteBuffer();
127+
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
134128
if (PlatformDetection.isWindows()) shma.close();
135129
}
136130
}

0 commit comments

Comments
 (0)