|
20 | 20 | */ |
21 | 21 | package io.bioimage.modelrunner.pytorch.javacpp.shm; |
22 | 22 |
|
| 23 | +import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; |
23 | 24 | import io.bioimage.modelrunner.system.PlatformDetection; |
24 | 25 | import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; |
25 | 26 | import io.bioimage.modelrunner.utils.CommonUtils; |
26 | 27 |
|
27 | 28 | import java.io.IOException; |
28 | 29 | import java.nio.ByteBuffer; |
| 30 | +import java.nio.FloatBuffer; |
29 | 31 | import java.util.Arrays; |
30 | 32 |
|
31 | 33 | import org.bytedeco.pytorch.Tensor; |
32 | 34 |
|
33 | 35 | import net.imglib2.type.numeric.integer.IntType; |
34 | 36 | import net.imglib2.type.numeric.integer.LongType; |
| 37 | +import net.imglib2.RandomAccessibleInterval; |
35 | 38 | import net.imglib2.type.numeric.integer.ByteType; |
36 | 39 | import net.imglib2.type.numeric.real.DoubleType; |
37 | 40 | import net.imglib2.type.numeric.real.FloatType; |
@@ -96,7 +99,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws |
96 | 99 | throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape) |
97 | 100 | + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4); |
98 | 101 | SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true); |
99 | | - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); |
| 102 | + RandomAccessibleInterval<?> rai = shma.getSharedRAI(); |
| 103 | + rai = ImgLib2Builder.build(tensor); |
100 | 104 | if (PlatformDetection.isWindows()) shma.close(); |
101 | 105 | } |
102 | 106 |
|
|
0 commit comments