
Commit 77ffb94

Remove transpose of input if n_chan=1
1 parent 56db25e
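
The idea behind the change: for an io_parallel model whose input has a single channel, moving that channel axis from the front (channels-first) to the back (channels-last) only permutes a size-1 axis, so the flat, row-major element order of the underlying array is unchanged and the input-side Transpose is a no-op. A minimal numpy sketch illustrating this (illustration only, not part of the commit):

    import numpy as np

    # Channels-first single-channel input, shape (1, 8, 8)
    x = np.arange(64).reshape(1, 8, 8)

    # Channels-last view, shape (8, 8, 1)
    x_t = x.transpose(1, 2, 0)

    # The C-order element sequence is identical, so for io_parallel the
    # transpose changes nothing about the stored array
    assert np.array_equal(x.flatten(), x_t.flatten())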

3 files changed: +104, −8 lines


hls4ml/model/optimizer/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -35,8 +35,9 @@
     [
         'channels_last_converter',
         'remove_transpose_before_flatten',
+        'remove_nop_transpose',
+        'remove_single_channel_transpose',
         'fuse_bias_add',
-        'remove_useless_transpose',
         'expand_layer_group',
         'output_rounding_saturation_mode',
         'qkeras_factorize_alpha',
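
Note the ordering in the flow: the two new passes replace remove_useless_transpose and are registered before fuse_bias_add, so degenerate transposes are eliminated early in the conversion flow.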
Lines changed: 33 additions & 7 deletions
@@ -1,21 +1,47 @@
-from hls4ml.model.layers import Transpose
+from hls4ml.model.layers import Input, Transpose
 from hls4ml.model.optimizer import OptimizerPass
 
 
-class RemoveUselessTranspose(OptimizerPass):
+class RemoveNopTranspose(OptimizerPass):
+    """
+    Remove a transpose layer if it doesn't do anything to a 1D array. i.e, 1D input and perm = [0]
+    """
+
     def match(self, node):
         is_match = isinstance(node, Transpose) and node.get_attr('perm') == [0]  # Useless transpose
         return is_match
 
     def transform(self, model, node):
-        """
-        Remove a transpose layer if it doesn't do anything. i.e 1D input and perm = [0]
-        """
-        print(f"Unnessary {node.name} in the model, optimizing ...")
+        print(f'Unnecessary transpose node ({node.name}) detected, optimizing ...')
         if not node.get_output_nodes():
-            print(f"WARNING: {node.name} is the output layer! No rewiring performed.")
+            print(f'WARNING: {node.name} is the output layer! No rewiring performed.')
             model.remove_node(node, rewire=False)  # Don't rewire if there is no output layer
         else:
             model.remove_node(node, rewire=True)
 
         return True
+
+
+class RemoveSingleChannelTranspose(OptimizerPass):
+    """
+    Remove transpose of inputs if the number of channels is 1 as for io_parallel this doesn't affect the array
+    representation used
+    """
+
+    def match(self, node):
+        if node.model.config.get_config_value('IOType') != 'io_parallel':
+            return False
+
+        return (
+            isinstance(node, Transpose)
+            and isinstance(node.get_input_node(), Input)
+            and node.get_input_variable().shape[0] == 1
+        )
+
+    def transform(self, model, node):
+        # Adjust the input shape and remove the Transpose node
+        input_var = node.get_input_variable()
+        input_var.shape.append(input_var.shape.pop(0))
+        model.remove_node(node)
+
+        return True
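
The shape bookkeeping in RemoveSingleChannelTranspose.transform reduces to rotating the leading (channel) axis to the back of the shape list; in isolation (illustration only):

    # Rotate n_chan = 1 from the front of the shape to the back
    shape = [1, 8, 8]           # channels-first input shape
    shape.append(shape.pop(0))  # -> [8, 8, 1], channels-last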

test/pytest/test_pytorch_api.py

Lines changed: 69 additions & 0 deletions
@@ -740,3 +740,72 @@ def test_skipped_layers(backend, io_type):
     hls_prediction = hls_model.predict(hls_input).flatten()
 
     np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
+
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel'])  # Only io_parallel for now
+@pytest.mark.parametrize('tensor_rank', [2, 3])
+def test_remove_transpose(backend, io_type, tensor_rank):
+    class TestModel(nn.Module):
+        def __init__(self, tensor_rank):
+            super().__init__()
+            if tensor_rank == 2:
+                self.conv1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+            else:
+                self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+
+        def forward(self, x):
+            # In the hls4ml model, there should be a Transpose node on the input tensor before conv1
+            x = self.conv1(x)
+            x = self.relu1(x)
+            x = self.flatten(x)  # This should result in a Transpose node that we aim to remove
+            x = self.fc1(x)
+            x = self.relu2(x)
+            return x
+
+    model = TestModel(tensor_rank=tensor_rank)
+    if tensor_rank == 2:
+        input_shape = (1, 8)
+        input_tensor = torch.randn(10, 1, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 1)).detach().numpy())
+    else:
+        input_shape = (1, 8, 8)
+        input_tensor = torch.randn(10, 1, 8, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 3, 1)).detach().numpy())
+
+    batch_input_shape = (None,) + input_shape
+    config = config_from_pytorch_model(
+        model,
+        default_precision='ap_fixed<32,16>',
+        inputs_channel_last=False,  # Crucial for testing if the first Transpose was removed
+        transpose_outputs=False,
+    )
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_transpose_nop_{tensor_rank}d_{backend}_{io_type}')
+    hls_model = convert_from_pytorch_model(
+        model,
+        batch_input_shape,
+        hls_config=config,
+        output_dir=output_dir,
+        io_type=io_type,
+        backend=backend,
+    )
+
+    hls_model.compile()
+
+    # Test optimizers removed the two Transpose layers
+    transpose_layers = [layer for layer in list(hls_model.get_layers()) if layer.class_name == 'Transpose']
+    assert len(transpose_layers) == 0
+
+    # Test predictions match
+    pytorch_prediction = model(input_tensor).detach().numpy().flatten()
+    hls_prediction = hls_model.predict(hls_input).flatten()
+
+    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
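
As a sanity check on the test's input handling: with a single channel, the channels-last copy handed to hls_model.predict carries the same flat data as the original NCHW tensor, which is why dropping the input-side Transpose cannot change the predictions. A standalone sketch (illustration only, mirroring the permute used in the test):

    import numpy as np
    import torch

    t = torch.randn(10, 1, 8, 8)                   # NCHW, n_chan = 1
    nhwc = torch.permute(t, (0, 2, 3, 1)).numpy()  # NHWC, as in the test
    assert np.array_equal(np.ascontiguousarray(nhwc).flatten(), t.numpy().flatten())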
