Skip to content

Commit aae46f6

Browse files
committed
correct javadoc
1 parent 27ff872 commit aae46f6

File tree

1 file changed

+28
-89
lines changed

1 file changed

+28
-89
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/TensorBuilder.java

Lines changed: 28 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,6 @@
2323

2424
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
2525
import io.bioimage.modelrunner.utils.CommonUtils;
26-
import net.imglib2.RandomAccessibleInterval;
27-
import net.imglib2.img.Img;
28-
import net.imglib2.type.numeric.integer.IntType;
29-
import net.imglib2.type.numeric.integer.LongType;
30-
import net.imglib2.type.numeric.integer.UnsignedByteType;
31-
import net.imglib2.type.numeric.real.DoubleType;
32-
import net.imglib2.type.numeric.real.FloatType;
3326
import net.imglib2.util.Cast;
3427

3528
import java.nio.ByteBuffer;
@@ -42,10 +35,9 @@
4235
import org.bytedeco.pytorch.Tensor;
4336

4437
/**
45-
* A TensorFlow 2 {@link Tensor} builder from {@link Img} and
46-
* {@link io.bioimage.modelrunner.tensor.Tensor} objects.
38+
* Utility class to build Pytorch Bytedeco tensors from shm segments using {@link SharedMemoryArray}
4739
*
48-
* @author Carlos Garcia Lopez de Haro and Daniel Felipe Gonzalez Obando
40+
* @author Carlos Garcia Lopez de Haro
4941
*/
5042
public final class TensorBuilder {
5143

@@ -55,19 +47,16 @@ public final class TensorBuilder {
5547
private TensorBuilder() {}
5648

5749
/**
58-
* Creates {@link TType} instance with the same size and information as the
59-
* given {@link RandomAccessibleInterval}.
50+
* Creates {@link Tensor} instance from a {@link SharedMemoryArray}
6051
*
61-
* @param <T>
62-
* the ImgLib2 data types the {@link RandomAccessibleInterval} can be
6352
* @param array
64-
* the {@link RandomAccessibleInterval} that is going to be converted into
65-
* a {@link TType} tensor
66-
* @return a {@link TType} tensor
67-
* @throws IllegalArgumentException if the type of the {@link RandomAccessibleInterval}
53+
* the {@link SharedMemoryArray} that is going to be converted into
54+
* a {@link Tensor} tensor
55+
* @return the Pytorch {@link Tensor} as the one stored in the shared memory segment
56+
* @throws IllegalArgumentException if the type of the {@link SharedMemoryArray}
6857
* is not supported
6958
*/
70-
public static org.bytedeco.pytorch.Tensor build(SharedMemoryArray array) throws IllegalArgumentException
59+
public static Tensor build(SharedMemoryArray array) throws IllegalArgumentException
7160
{
7261
// Create an Icy sequence of the same type of the tensor
7362
if (array.getOriginalDataType().equals("int8")) {
@@ -90,131 +79,81 @@ else if (array.getOriginalDataType().equals("int64")) {
9079
}
9180
}
9281

93-
/**
94-
* Creates a {@link TType} tensor of type {@link TUint8} from an
95-
* {@link RandomAccessibleInterval} of type {@link UnsignedByteType}
96-
*
97-
* @param tensor
98-
* The {@link RandomAccessibleInterval} to fill the tensor with.
99-
* @return The {@link TType} tensor filled with the {@link RandomAccessibleInterval} data.
100-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
101-
* not compatible
102-
*/
103-
public static org.bytedeco.pytorch.Tensor buildByte(SharedMemoryArray tensor)
82+
private static Tensor buildByte(SharedMemoryArray shmArray)
10483
throws IllegalArgumentException
10584
{
106-
long[] ogShape = tensor.getOriginalShape();
85+
long[] ogShape = shmArray.getOriginalShape();
10786
if (CommonUtils.int32Overflows(ogShape, 1))
10887
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
10988
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
110-
if (!tensor.isNumpyFormat())
89+
if (!shmArray.isNumpyFormat())
11190
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
112-
ByteBuffer buff = tensor.getDataBufferNoHeader();
91+
ByteBuffer buff = shmArray.getDataBufferNoHeader();
11392
Tensor ndarray = Tensor.create(buff.array(), ogShape);
11493
return ndarray;
11594
}
11695

117-
/**
118-
* Creates a {@link TInt32} tensor of type {@link TInt32} from an
119-
* {@link RandomAccessibleInterval} of type {@link IntType}
120-
*
121-
* @param tensor
122-
* The {@link RandomAccessibleInterval} to fill the tensor with.
123-
* @return The {@link TInt32} tensor filled with the {@link RandomAccessibleInterval} data.
124-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
125-
* not compatible
126-
*/
127-
public static Tensor buildInt(SharedMemoryArray tensor)
96+
private static Tensor buildInt(SharedMemoryArray shmaArray)
12897
throws IllegalArgumentException
12998
{
130-
long[] ogShape = tensor.getOriginalShape();
99+
long[] ogShape = shmaArray.getOriginalShape();
131100
if (CommonUtils.int32Overflows(ogShape, 1))
132101
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
133102
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
134-
if (!tensor.isNumpyFormat())
103+
if (!shmaArray.isNumpyFormat())
135104
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
136-
ByteBuffer buff = tensor.getDataBufferNoHeader();
105+
ByteBuffer buff = shmaArray.getDataBufferNoHeader();
137106
IntBuffer intBuff = buff.asIntBuffer();
138107
int[] intArray = new int[intBuff.capacity()];
139108
intBuff.get(intArray);
140109
Tensor ndarray = Tensor.create(intBuff.array(), ogShape);
141110
return ndarray;
142111
}
143112

144-
/**
145-
* Creates a {@link TInt64} tensor of type {@link TInt64} from an
146-
* {@link RandomAccessibleInterval} of type {@link LongType}
147-
*
148-
* @param tensor
149-
* The {@link RandomAccessibleInterval} to fill the tensor with.
150-
* @return The {@link TInt64} tensor filled with the {@link RandomAccessibleInterval} data.
151-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
152-
* not compatible
153-
*/
154-
private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray tensor)
113+
private static org.bytedeco.pytorch.Tensor buildLong(SharedMemoryArray shmArray)
155114
throws IllegalArgumentException
156115
{
157-
long[] ogShape = tensor.getOriginalShape();
116+
long[] ogShape = shmArray.getOriginalShape();
158117
if (CommonUtils.int32Overflows(ogShape, 1))
159118
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
160119
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
161-
if (!tensor.isNumpyFormat())
120+
if (!shmArray.isNumpyFormat())
162121
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
163-
ByteBuffer buff = tensor.getDataBufferNoHeader();
122+
ByteBuffer buff = shmArray.getDataBufferNoHeader();
164123
LongBuffer longBuff = buff.asLongBuffer();
165124
long[] longArray = new long[longBuff.capacity()];
166125
longBuff.get(longArray);
167126
Tensor ndarray = Tensor.create(longBuff.array(), ogShape);
168127
return ndarray;
169128
}
170129

171-
/**
172-
* Creates a {@link TFloat32} tensor of type {@link TFloat32} from an
173-
* {@link RandomAccessibleInterval} of type {@link FloatType}
174-
*
175-
* @param tensor
176-
* The {@link RandomAccessibleInterval} to fill the tensor with.
177-
* @return The {@link TFloat32} tensor filled with the {@link RandomAccessibleInterval} data.
178-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
179-
* not compatible
180-
*/
181-
public static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray tensor)
130+
private static org.bytedeco.pytorch.Tensor buildFloat(SharedMemoryArray shmArray)
182131
throws IllegalArgumentException
183132
{
184-
long[] ogShape = tensor.getOriginalShape();
133+
long[] ogShape = shmArray.getOriginalShape();
185134
if (CommonUtils.int32Overflows(ogShape, 1))
186135
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
187136
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
188-
if (!tensor.isNumpyFormat())
137+
if (!shmArray.isNumpyFormat())
189138
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
190-
ByteBuffer buff = tensor.getDataBufferNoHeader();
139+
ByteBuffer buff = shmArray.getDataBufferNoHeader();
191140
FloatBuffer floatBuff = buff.asFloatBuffer();
192141
float[] floatArray = new float[floatBuff.capacity()];
193142
floatBuff.get(floatArray);
194143
Tensor ndarray = Tensor.create(floatBuff.array(), ogShape);
195144
return ndarray;
196145
}
197146

198-
/**
199-
* Creates a {@link TFloat64} tensor of type {@link TFloat64} from an
200-
* {@link RandomAccessibleInterval} of type {@link DoubleType}
201-
*
202-
* @param tensor
203-
* The {@link RandomAccessibleInterval} to fill the tensor with.
204-
* @return The {@link TFloat64} tensor filled with the {@link RandomAccessibleInterval} data.
205-
* @throws IllegalArgumentException if the input {@link RandomAccessibleInterval} type is
206-
* not compatible
207-
*/
208-
private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray tensor)
147+
private static org.bytedeco.pytorch.Tensor buildDouble(SharedMemoryArray shmArray)
209148
throws IllegalArgumentException
210149
{
211-
long[] ogShape = tensor.getOriginalShape();
150+
long[] ogShape = shmArray.getOriginalShape();
212151
if (CommonUtils.int32Overflows(ogShape, 1))
213152
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
214153
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
215-
if (!tensor.isNumpyFormat())
154+
if (!shmArray.isNumpyFormat())
216155
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
217-
ByteBuffer buff = tensor.getDataBufferNoHeader();
156+
ByteBuffer buff = shmArray.getDataBufferNoHeader();
218157
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
219158
double[] doubleArray = new double[doubleBuff.capacity()];
220159
doubleBuff.get(doubleArray);

0 commit comments

Comments
 (0)