2020 */
2121package io .bioimage .modelrunner .pytorch .javacpp .shm ;
2222
23- import io .bioimage .modelrunner .pytorch .javacpp .tensor .ImgLib2Builder ;
2423import io .bioimage .modelrunner .system .PlatformDetection ;
2524import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
2625import io .bioimage .modelrunner .utils .CommonUtils ;
2726
2827import java .io .IOException ;
2928import java .nio .ByteBuffer ;
29+ import java .nio .DoubleBuffer ;
3030import java .nio .FloatBuffer ;
31+ import java .nio .IntBuffer ;
32+ import java .nio .LongBuffer ;
3133import java .util .Arrays ;
3234
3335import org .bytedeco .pytorch .Tensor ;
3436
3537import net .imglib2 .type .numeric .integer .IntType ;
3638import net .imglib2 .type .numeric .integer .LongType ;
37- import net .imglib2 .RandomAccessibleInterval ;
3839import net .imglib2 .type .numeric .integer .ByteType ;
3940import net .imglib2 .type .numeric .real .DoubleType ;
4041import net .imglib2 .type .numeric .real .FloatType ;
@@ -88,7 +89,14 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
8889 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
8990 + " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
9091 SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new ByteType (), false , true );
91- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
92+ long flatSize = 1 ;
93+ for (long l : arrayShape ) {flatSize *= l ;}
94+ byte [] flat = new byte [(int ) flatSize ];
95+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize ));
96+ tensor .data_ptr_byte ().get (flat );
97+ byteBuffer .put (flat );
98+ byteBuffer .rewind ();
99+ shma .getDataBufferNoHeader ().put (byteBuffer );
92100 if (PlatformDetection .isWindows ()) shma .close ();
93101 }
94102
@@ -99,8 +107,15 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
99107 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
100108 + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
101109 SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
102- RandomAccessibleInterval <?> rai = shma .getSharedRAI ();
103- rai = ImgLib2Builder .build (tensor );
110+ long flatSize = 1 ;
111+ for (long l : arrayShape ) {flatSize *= l ;}
112+ int [] flat = new int [(int ) flatSize ];
113+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Integer .BYTES ));
114+ IntBuffer floatBuffer = byteBuffer .asIntBuffer ();
115+ tensor .data_ptr_int ().get (flat );
116+ floatBuffer .put (flat );
117+ byteBuffer .rewind ();
118+ shma .getDataBufferNoHeader ().put (byteBuffer );
104119 if (PlatformDetection .isWindows ()) shma .close ();
105120 }
106121
@@ -130,7 +145,15 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
130145 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
131146 + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
132147 SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
133- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
148+ long flatSize = 1 ;
149+ for (long l : arrayShape ) {flatSize *= l ;}
150+ double [] flat = new double [(int ) flatSize ];
151+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Double .BYTES ));
152+ DoubleBuffer floatBuffer = byteBuffer .asDoubleBuffer ();
153+ tensor .data_ptr_double ().get (flat );
154+ floatBuffer .put (flat );
155+ byteBuffer .rewind ();
156+ shma .getDataBufferNoHeader ().put (byteBuffer );
134157 if (PlatformDetection .isWindows ()) shma .close ();
135158 }
136159
@@ -141,7 +164,15 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
141164 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
142165 + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
143166 SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
144- shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
167+ long flatSize = 1 ;
168+ for (long l : arrayShape ) {flatSize *= l ;}
169+ long [] flat = new long [(int ) flatSize ];
170+ ByteBuffer byteBuffer = ByteBuffer .allocateDirect ((int ) (flatSize * Long .BYTES ));
171+ LongBuffer floatBuffer = byteBuffer .asLongBuffer ();
172+ tensor .data_ptr_long ().get (flat );
173+ floatBuffer .put (flat );
174+ byteBuffer .rewind ();
175+ shma .getDataBufferNoHeader ().put (byteBuffer );
145176 if (PlatformDetection .isWindows ()) shma .close ();
146177 }
147178}
0 commit comments