@@ -740,3 +740,72 @@ def test_skipped_layers(backend, io_type):
     hls_prediction = hls_model.predict(hls_input).flatten()

     np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
+
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel'])  # Only io_parallel for now
+@pytest.mark.parametrize('tensor_rank', [2, 3])
+def test_remove_transpose(backend, io_type, tensor_rank):
+    class TestModel(nn.Module):
+        def __init__(self, tensor_rank):
+            super().__init__()
+            if tensor_rank == 2:
+                self.conv1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+            else:
+                self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+
+        def forward(self, x):
+            # In the hls4ml model, there should be a Transpose node on the input tensor before conv1
+            x = self.conv1(x)
+            x = self.relu1(x)
+            x = self.flatten(x)  # This should result in a Transpose node that we aim to remove
+            x = self.fc1(x)
+            x = self.relu2(x)
+            return x
+
+    model = TestModel(tensor_rank=tensor_rank)
+    if tensor_rank == 2:
+        input_shape = (1, 8)
+        input_tensor = torch.randn(10, 1, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 1)).detach().numpy())
+    else:
+        input_shape = (1, 8, 8)
+        input_tensor = torch.randn(10, 1, 8, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 3, 1)).detach().numpy())
+
+    batch_input_shape = (None,) + input_shape
+    config = config_from_pytorch_model(
+        model,
+        default_precision='ap_fixed<32,16>',
+        inputs_channel_last=False,  # Crucial for testing if the first Transpose was removed
+        transpose_outputs=False,
+    )
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_transpose_nop_{tensor_rank}d_{backend}_{io_type}')
+    hls_model = convert_from_pytorch_model(
+        model,
+        batch_input_shape,
+        hls_config=config,
+        output_dir=output_dir,
+        io_type=io_type,
+        backend=backend,
+    )
+
+    hls_model.compile()
+
+    # Test optimizers removed the two Transpose layers
+    transpose_layers = [layer for layer in list(hls_model.get_layers()) if layer.class_name == 'Transpose']
+    assert len(transpose_layers) == 0
+
+    # Test predictions match
+    pytorch_prediction = model(input_tensor).detach().numpy().flatten()
+    hls_prediction = hls_model.predict(hls_input).flatten()
+
+    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
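A note on the data layout this test hinges on: hls4ml consumes channels-last (NHWC) arrays, while PyTorch produces channels-first (NCHW) tensors, so the test permutes the input before calling predict(). A minimal standalone sketch of that conversion (variable names here are illustrative, not taken from the patch):

    import numpy as np
    import torch

    x = torch.randn(10, 1, 8, 8)  # NCHW: (batch, channels, height, width)
    # Move channels last and force a contiguous buffer for hls4ml's predict()
    x_nhwc = np.ascontiguousarray(torch.permute(x, (0, 2, 3, 1)).detach().numpy())
    assert x_nhwc.shape == (10, 8, 8, 1)  # NHWC

The fc1 in_features values in the test follow from the convolution arithmetic: a kernel-3, unpadded conv shrinks the length-8 input to 6, giving 4 * 6 flattened features in 1D and 4 * 6 * 6 in 2D.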