2929
3030import org .bytedeco .pytorch .Tensor ;
3131
32- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
32+ import net .imglib2 .type .numeric .integer .IntType ;
33+ import net .imglib2 .type .numeric .integer .LongType ;
34+ import net .imglib2 .type .numeric .integer .ByteType ;
35+ import net .imglib2 .type .numeric .real .DoubleType ;
36+ import net .imglib2 .type .numeric .real .FloatType ;
3337
3438/**
3539 * A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for
@@ -79,7 +83,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
7983 if (CommonUtils .int32Overflows (arrayShape , 1 ))
8084 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
8185 + " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
82- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
86+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new ByteType (), false , true );
8387 shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
8488 if (PlatformDetection .isWindows ()) shma .close ();
8589 }
@@ -90,7 +94,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
9094 if (CommonUtils .int32Overflows (arrayShape , 4 ))
9195 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
9296 + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
93- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
97+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new IntType (), false , true );
9498 shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
9599 if (PlatformDetection .isWindows ()) shma .close ();
96100 }
@@ -101,7 +105,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
101105 if (CommonUtils .int32Overflows (arrayShape , 4 ))
102106 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
103107 + " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
104- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
108+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new FloatType (), false , true );
105109 shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
106110 if (PlatformDetection .isWindows ()) shma .close ();
107111 }
@@ -112,7 +116,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
112116 if (CommonUtils .int32Overflows (arrayShape , 8 ))
113117 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
114118 + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
115- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
119+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new DoubleType (), false , true );
116120 shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
117121 if (PlatformDetection .isWindows ()) shma .close ();
118122 }
@@ -123,7 +127,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
123127 if (CommonUtils .int32Overflows (arrayShape , 8 ))
124128 throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
125129 + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
126- SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new UnsignedByteType (), false , true );
130+ SharedMemoryArray shma = SharedMemoryArray .readOrCreate (memoryName , arrayShape , new LongType (), false , true );
127131 shma .getDataBufferNoHeader ().put (tensor .asByteBuffer ());
128132 if (PlatformDetection .isWindows ()) shma .close ();
129133 }
0 commit comments