2323
2424import io .bioimage .modelrunner .tensor .shm .SharedMemoryArray ;
2525import io .bioimage .modelrunner .utils .CommonUtils ;
26- import net .imglib2 .RandomAccessibleInterval ;
27- import net .imglib2 .img .Img ;
28- import net .imglib2 .type .numeric .integer .IntType ;
29- import net .imglib2 .type .numeric .integer .LongType ;
30- import net .imglib2 .type .numeric .integer .UnsignedByteType ;
31- import net .imglib2 .type .numeric .real .DoubleType ;
32- import net .imglib2 .type .numeric .real .FloatType ;
3326import net .imglib2 .util .Cast ;
3427
3528import java .nio .ByteBuffer ;
4235import org .bytedeco .pytorch .Tensor ;
4336
4437/**
45- * A TensorFlow 2 {@link Tensor} builder from {@link Img} and
46- * {@link io.bioimage.modelrunner.tensor.Tensor} objects.
38+ * Utility class to build Pytorch Bytedeco tensors from shm segments using {@link SharedMemoryArray}
4739 *
48- * @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
40+ * @author Carlos Garcia Lopez de Haro
4941 */
5042public final class TensorBuilder {
5143
@@ -55,19 +47,16 @@ public final class TensorBuilder {
5547 private TensorBuilder () {}
5648
5749 /**
58- * Creates {@link TType} instance with the same size and information as the
59- * given {@link RandomAccessibleInterval}.
50+ * Creates {@link Tensor} instance from a {@link SharedMemoryArray}
6051 *
61- * @param <T>
62- * the ImgLib2 data types the {@link RandomAccessibleInterval} can be
6352 * @param array
64- * the {@link RandomAccessibleInterval } that is going to be converted into
65- * a {@link TType } tensor
66- * @return a {@link TType} tensor
67- * @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval }
53+ * the {@link SharedMemoryArray } that is going to be converted into
54+ * a {@link Tensor } tensor
55+ * @return the Pytorch {@link Tensor} as the one stored in the shared memory segment
56+ * @throws IllegalArgumentException if the type of the {@link SharedMemoryArray }
6857 * is not supported
6958 */
70- public static org . bytedeco . pytorch . Tensor build (SharedMemoryArray array ) throws IllegalArgumentException
59+ public static Tensor build (SharedMemoryArray array ) throws IllegalArgumentException
7160 {
7261 // Create an Icy sequence of the same type of the tensor
7362 if (array .getOriginalDataType ().equals ("int8" )) {
@@ -90,131 +79,81 @@ else if (array.getOriginalDataType().equals("int64")) {
9079 }
9180 }
9281
93- /**
94- * Creates a {@link TType} tensor of type {@link TUint8} from an
95- * {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
96- *
97- * @param tensor
98- * The {@link RandomAccessibleInterval} to fill the tensor with.
99- * @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
100- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
101- * not compatible
102- */
103- public static org .bytedeco .pytorch .Tensor buildByte (SharedMemoryArray tensor )
82+ private static Tensor buildByte (SharedMemoryArray shmArray )
10483 throws IllegalArgumentException
10584 {
106- long [] ogShape = tensor .getOriginalShape ();
85+ long [] ogShape = shmArray .getOriginalShape ();
10786 if (CommonUtils .int32Overflows (ogShape , 1 ))
10887 throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
10988 + " is too big. Max number of elements per ubyte tensor supported: " + Integer .MAX_VALUE );
110- if (!tensor .isNumpyFormat ())
89+ if (!shmArray .isNumpyFormat ())
11190 throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
112- ByteBuffer buff = tensor .getDataBufferNoHeader ();
91+ ByteBuffer buff = shmArray .getDataBufferNoHeader ();
11392 Tensor ndarray = Tensor .create (buff .array (), ogShape );
11493 return ndarray ;
11594 }
11695
117- /**
118- * Creates a {@link TInt32} tensor of type {@link TInt32} from an
119- * {@link RandomAccessibleInterval} of type {@link IntType}
120- *
121- * @param tensor
122- * The {@link RandomAccessibleInterval} to fill the tensor with.
123- * @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
124- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
125- * not compatible
126- */
127- public static Tensor buildInt (SharedMemoryArray tensor )
96+ private static Tensor buildInt (SharedMemoryArray shmaArray )
12897 throws IllegalArgumentException
12998 {
130- long [] ogShape = tensor .getOriginalShape ();
99+ long [] ogShape = shmaArray .getOriginalShape ();
131100 if (CommonUtils .int32Overflows (ogShape , 1 ))
132101 throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
133102 + " is too big. Max number of elements per ubyte tensor supported: " + Integer .MAX_VALUE );
134- if (!tensor .isNumpyFormat ())
103+ if (!shmaArray .isNumpyFormat ())
135104 throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
136- ByteBuffer buff = tensor .getDataBufferNoHeader ();
105+ ByteBuffer buff = shmaArray .getDataBufferNoHeader ();
137106 IntBuffer intBuff = buff .asIntBuffer ();
138107 int [] intArray = new int [intBuff .capacity ()];
139108 intBuff .get (intArray );
140109 Tensor ndarray = Tensor .create (intBuff .array (), ogShape );
141110 return ndarray ;
142111 }
143112
144- /**
145- * Creates a {@link TInt64} tensor of type {@link TInt64} from an
146- * {@link RandomAccessibleInterval} of type {@link LongType}
147- *
148- * @param tensor
149- * The {@link RandomAccessibleInterval} to fill the tensor with.
150- * @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
151- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
152- * not compatible
153- */
154- private static org .bytedeco .pytorch .Tensor buildLong (SharedMemoryArray tensor )
113+ private static org .bytedeco .pytorch .Tensor buildLong (SharedMemoryArray shmArray )
155114 throws IllegalArgumentException
156115 {
157- long [] ogShape = tensor .getOriginalShape ();
116+ long [] ogShape = shmArray .getOriginalShape ();
158117 if (CommonUtils .int32Overflows (ogShape , 1 ))
159118 throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
160119 + " is too big. Max number of elements per ubyte tensor supported: " + Integer .MAX_VALUE );
161- if (!tensor .isNumpyFormat ())
120+ if (!shmArray .isNumpyFormat ())
162121 throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
163- ByteBuffer buff = tensor .getDataBufferNoHeader ();
122+ ByteBuffer buff = shmArray .getDataBufferNoHeader ();
164123 LongBuffer longBuff = buff .asLongBuffer ();
165124 long [] longArray = new long [longBuff .capacity ()];
166125 longBuff .get (longArray );
167126 Tensor ndarray = Tensor .create (longBuff .array (), ogShape );
168127 return ndarray ;
169128 }
170129
171- /**
172- * Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
173- * {@link RandomAccessibleInterval} of type {@link FloatType}
174- *
175- * @param tensor
176- * The {@link RandomAccessibleInterval} to fill the tensor with.
177- * @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
178- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
179- * not compatible
180- */
181- public static org .bytedeco .pytorch .Tensor buildFloat (SharedMemoryArray tensor )
130+ private static org .bytedeco .pytorch .Tensor buildFloat (SharedMemoryArray shmArray )
182131 throws IllegalArgumentException
183132 {
184- long [] ogShape = tensor .getOriginalShape ();
133+ long [] ogShape = shmArray .getOriginalShape ();
185134 if (CommonUtils .int32Overflows (ogShape , 1 ))
186135 throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
187136 + " is too big. Max number of elements per ubyte tensor supported: " + Integer .MAX_VALUE );
188- if (!tensor .isNumpyFormat ())
137+ if (!shmArray .isNumpyFormat ())
189138 throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
190- ByteBuffer buff = tensor .getDataBufferNoHeader ();
139+ ByteBuffer buff = shmArray .getDataBufferNoHeader ();
191140 FloatBuffer floatBuff = buff .asFloatBuffer ();
192141 float [] floatArray = new float [floatBuff .capacity ()];
193142 floatBuff .get (floatArray );
194143 Tensor ndarray = Tensor .create (floatBuff .array (), ogShape );
195144 return ndarray ;
196145 }
197146
198- /**
199- * Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
200- * {@link RandomAccessibleInterval} of type {@link DoubleType}
201- *
202- * @param tensor
203- * The {@link RandomAccessibleInterval} to fill the tensor with.
204- * @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
205- * @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
206- * not compatible
207- */
208- private static org .bytedeco .pytorch .Tensor buildDouble (SharedMemoryArray tensor )
147+ private static org .bytedeco .pytorch .Tensor buildDouble (SharedMemoryArray shmArray )
209148 throws IllegalArgumentException
210149 {
211- long [] ogShape = tensor .getOriginalShape ();
150+ long [] ogShape = shmArray .getOriginalShape ();
212151 if (CommonUtils .int32Overflows (ogShape , 1 ))
213152 throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
214153 + " is too big. Max number of elements per ubyte tensor supported: " + Integer .MAX_VALUE );
215- if (!tensor .isNumpyFormat ())
154+ if (!shmArray .isNumpyFormat ())
216155 throw new IllegalArgumentException ("Shared memory arrays must be saved in numpy format." );
217- ByteBuffer buff = tensor .getDataBufferNoHeader ();
156+ ByteBuffer buff = shmArray .getDataBufferNoHeader ();
218157 DoubleBuffer doubleBuff = buff .asDoubleBuffer ();
219158 double [] doubleArray = new double [doubleBuff .capacity ()];
220159 doubleBuff .get (doubleArray );
0 commit comments