@@ -25,37 +25,53 @@ namespace neutron {
// Round `size` up to the next multiple of BUFFER_ALIGNMENT
// (BUFFER_ALIGNMENT must be a power of two). Every parameter use is
// parenthesized so expression arguments expand correctly.
#define ALIGN_SIZE(size) \
  (((size) + BUFFER_ALIGNMENT - 1) & (~(BUFFER_ALIGNMENT - 1)))

// clang-format off
/* Header schema:
   +----------------------------+-----------------------------+------------------------+
   | Neutron inputs length (1B) | Neutron outputs length (1B) | Input args length (1B) |
   +----------------------------+-----------+-----------------+------------------------+
   | 1st input tensor format (1B)           | [nth* input tensor format (1B)]          |
   +----------------------------------------+------------------------------------------+
   | 1st output tensor format (1B)          | [nth* output tensor format (1B)]         |
   +----------------------------------------+------------------------------------------+
   | 1st input map (1B)                     | [nth* input map (1B)]                    |
   +----------------------------------------+------------------------------------------+
   | 1st output map (1B)                    | [nth* output map (1B)]                   |
   +----------------------------------------+------------------------------------------+
*/
// clang-format on
#define ITEM_SIZE 1 // 1 Byte
// Byte offsets of the three length fields at the start of the header.
#define INPUT_TENSOR_FORMAT_LEN_POS 0
#define OUTPUT_TENSOR_FORMAT_LEN_POS 1
#define INPUT_ARGS_LEN_POS 2
// `base` is a `const uint8_t*` to the start of the header; the arrays
// below follow the three 1-byte length fields, in schema order.
#define INPUT_TENSOR_FORMAT_ARRAY_ADDR(base) ((base) + 3 * ITEM_SIZE)
#define OUTPUT_TENSOR_FORMAT_ARRAY_ADDR(base) \
  ((base) + 3 * ITEM_SIZE + (base)[INPUT_TENSOR_FORMAT_LEN_POS])
#define INPUT_TENSOR_MAP_ARRAY_ADDR(base)                             \
  ((base) + 3 * ITEM_SIZE + 1 * (base)[INPUT_TENSOR_FORMAT_LEN_POS] + \
   1 * (base)[OUTPUT_TENSOR_FORMAT_LEN_POS])
#define OUTPUT_TENSOR_MAP_ARRAY_ADDR(base)                            \
  ((base) + 3 * ITEM_SIZE + 2 * (base)[INPUT_TENSOR_FORMAT_LEN_POS] + \
   1 * (base)[OUTPUT_TENSOR_FORMAT_LEN_POS])
// Payload (Neutron microcode) starts at the aligned end of the header.
#define PAYLOAD_ADDR(base)                                            \
  ((base) +                                                           \
   ALIGN_SIZE(                                                        \
       3 * ITEM_SIZE + 2 * (base)[INPUT_TENSOR_FORMAT_LEN_POS] +      \
       2 * (base)[OUTPUT_TENSOR_FORMAT_LEN_POS]))
4861
4962// Aggregate neutron model handle and data structures into one.
5063typedef struct {
5164 int numInputs = 0 ;
5265 int numOutputs = 0 ;
66+ int numInputArgs = 0 ;
5367 uint32_t scratchSize = 0 ;
5468 NeutronModelConfig mcfg;
5569 NeutronDataConfig dcfg;
5670 NeutronModelHandle nmh = NULL ;
5771 const uint8_t * inputTranspositionFlags;
5872 const uint8_t * outputTranspositionFlags;
73+ const uint8_t * inputMap;
74+ const uint8_t * outputMap;
5975} NeutronConfig;
6076
6177// Applied on outputs.
@@ -210,6 +226,15 @@ void transposeOutput(
210226 }
211227}
212228
229+ bool multipleChannelsPresent (const ArrayRef<exec_aten::SizesType>& sizes) {
230+ size_t length = sizes.size ();
231+ if (length < 3 ) {
232+ return true ;
233+ }
234+ size_t C = sizes[length - 3 ];
235+ return C != 1 ;
236+ }
237+
213238class NeutronBackend final : public PyTorchBackendInterface {
214239 public:
215240 NeutronBackend () {}
@@ -234,17 +259,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
234259 // cfg->mcfg.microcode
235260 // cfg->mcfg.weights
236261 // cfg->mcfg.kernels
237- const uint8_t * transpositionFlags =
262+ const uint8_t * payloadFlags =
238263 static_cast <const uint8_t *>(processed->data ());
239- int numInputs = transpositionFlags [INPUT_TENSOR_FORMAT_LEN_POS];
240- int numOutputs = transpositionFlags [OUTPUT_TENSOR_FORMAT_LEN_POS];
241- cfg->inputTranspositionFlags =
242- INPUT_TENSOR_FORMAT_ARRAY_ADDR (transpositionFlags );
264+ uint32_t numInputs = payloadFlags [INPUT_TENSOR_FORMAT_LEN_POS];
265+ uint32_t numOutputs = payloadFlags [OUTPUT_TENSOR_FORMAT_LEN_POS];
266+ cfg->numInputArgs = payloadFlags[INPUT_ARGS_LEN_POS];
267+ cfg-> inputTranspositionFlags = INPUT_TENSOR_FORMAT_ARRAY_ADDR (payloadFlags );
243268 cfg->outputTranspositionFlags =
244- OUTPUT_TENSOR_FORMAT_ARRAY_ADDR (transpositionFlags);
269+ OUTPUT_TENSOR_FORMAT_ARRAY_ADDR (payloadFlags);
270+ cfg->inputMap = INPUT_TENSOR_MAP_ARRAY_ADDR (payloadFlags);
271+ cfg->outputMap = OUTPUT_TENSOR_MAP_ARRAY_ADDR (payloadFlags);
245272
246273 const uint32_t * buffer = static_cast <const uint32_t *>(
247- static_cast <const void *> PAYLOAD_ADDR (transpositionFlags ));
274+ static_cast <const void *> PAYLOAD_ADDR (payloadFlags ));
248275 uint32_t magicWord = buffer[0 ];
249276 // Check valid microcode.
250277 if (magicWord != 0x64434D6E ) {
@@ -314,39 +341,37 @@ class NeutronBackend final : public PyTorchBackendInterface {
314341 cfg->dcfg .outputs [cfg->numOutputs ] =
315342 static_cast <void *>(context.allocate (cfg->scratchSize , 16 ));
316343
317- // Set inputs and outputs from args.
344+ // Set inputs from args.
345+ // Transpose inputs if needed.
318346 for (int i = 0 ; i < cfg->numInputs ; i++) {
319- cfg->dcfg .inputs [i] = args[i]->toTensor ().const_data_ptr ();
320- }
321- for (int i = 0 ; i < cfg->numOutputs ; i++) {
322- cfg->dcfg .outputs [i] =
323- args[cfg->numInputs + i]->toTensor ().mutable_data_ptr ();
324- }
325-
326- // Transpose inputs.
327- for (int i = 0 ; i < cfg->numInputs ; i++) {
328- if (cfg->inputTranspositionFlags [i]) {
329- if (args[i]->toTensor ().sizes ().size () < 3 ) {
347+ auto arg = args[cfg->inputMap [i]]->toTensor ();
348+ if (cfg->inputTranspositionFlags [i] &&
349+ multipleChannelsPresent (arg.sizes ())) {
350+ if (arg.sizes ().size () < 3 ) {
330351 ET_LOG (Error, " Unable to transpose 1D and 2D input to channel last" );
331352 return Error::InvalidProgram;
332353 }
333354 // Allocate buffer, the allocator is reset after each PTE instruction.
334- void * buffer = context.allocate (args[i]-> toTensor () .nbytes (), 16 );
355+ void * buffer = context.allocate (arg .nbytes ());
335356 transposeInput (
336- args[i]->toTensor ().const_data_ptr (),
337- buffer,
338- args[i]->toTensor ().sizes (),
339- args[i]->toTensor ().element_size ());
357+ arg.const_data_ptr (), buffer, arg.sizes (), arg.element_size ());
340358 cfg->dcfg .inputs [i] = buffer;
359+ } else {
360+ cfg->dcfg .inputs [i] = arg.const_data_ptr ();
341361 }
342362 }
343- // Redirect outputs.
363+
364+ // Set outputs from args.
365+ // Redirect outputs if needed before transposition.
344366 for (int i = 0 ; i < cfg->numOutputs ; i++) {
345- if (cfg->outputTranspositionFlags [i]) {
367+ auto arg = args[cfg->numInputArgs + cfg->outputMap [i]]->toTensor ();
368+ if (cfg->outputTranspositionFlags [i] &&
369+ multipleChannelsPresent (arg.sizes ())) {
346370 // Allocate buffer, the allocator is reset after each PTE instruction.
347- void * buffer =
348- context.allocate (args[cfg->numInputs + i]->toTensor ().nbytes (), 16 );
371+ void * buffer = context.allocate (arg.nbytes ());
349372 cfg->dcfg .outputs [i] = buffer;
373+ } else {
374+ cfg->dcfg .outputs [i] = arg.mutable_data_ptr ();
350375 }
351376 }
352377
@@ -368,17 +393,19 @@ class NeutronBackend final : public PyTorchBackendInterface {
368393
369394 // Transpose outputs.
370395 for (int i = 0 ; i < cfg->numOutputs ; i++) {
371- if (cfg->outputTranspositionFlags [i]) {
372- if (args[cfg->numInputs + i]->toTensor ().sizes ().size () < 3 ) {
396+ auto arg = args[cfg->numInputArgs + cfg->outputMap [i]]->toTensor ();
397+ if (cfg->outputTranspositionFlags [i] &&
398+ multipleChannelsPresent (arg.sizes ())) {
399+ if (arg.sizes ().size () < 3 ) {
373400 ET_LOG (
374401 Error, " Unable to transpose 1D and 2D output to channel first" );
375402 return Error::InvalidProgram;
376403 }
377404 transposeOutput (
378405 cfg->dcfg .outputs [i],
379- args[cfg-> numInputs + i]-> toTensor () .mutable_data_ptr (),
380- args[cfg-> numInputs + i]-> toTensor () .sizes (),
381- args[cfg-> numInputs + i]-> toTensor () .element_size ());
406+ arg .mutable_data_ptr (),
407+ arg .sizes (),
408+ arg .element_size ());
382409 }
383410 }
384411
0 commit comments