
Commit a357b7a

Merge pull request #976 from vloncar/channels_last_flatten
Remove unnecessary transposes related to conversion to channels_last format
2 parents: 1616caf + 295ba9f

File tree: 4 files changed (+166, −9 lines)

hls4ml/model/optimizer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -35,8 +35,10 @@
     [
         'infer_precision_types',
         'channels_last_converter',
+        'remove_transpose_before_flatten',
+        'remove_nop_transpose',
+        'remove_single_channel_transpose',
         'fuse_bias_add',
-        'remove_useless_transpose',
         'expand_layer_group',
         'output_rounding_saturation_mode',
         'qkeras_factorize_alpha',
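
Note: the snake_case names added to this flow correspond to the new optimizer pass classes introduced in the files below (RemoveTransposeBeforeFlatten, RemoveNopTranspose, RemoveSingleChannelTranspose). A minimal sketch of that naming correspondence, assuming the registered pass name is simply the CamelCase class name converted to snake_case; the registration mechanism itself is not part of this diff, and to_pass_name below is a hypothetical helper, not hls4ml API:

import re

def to_pass_name(cls_name: str) -> str:
    # Hypothetical helper: CamelCase class name -> snake_case flow entry (assumed convention)
    return re.sub(r'(?<!^)(?=[A-Z])', '_', cls_name).lower()

for cls in ('RemoveTransposeBeforeFlatten', 'RemoveNopTranspose', 'RemoveSingleChannelTranspose'):
    print(to_pass_name(cls))
# remove_transpose_before_flatten
# remove_nop_transpose
# remove_single_channel_transpose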

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 61 additions & 1 deletion
@@ -2,8 +2,9 @@
 # Based on https://github.com/fastmachinelearning/qonnx/blob/
 # 12c96a3ded06beacab08e0f554e4ed014476c0aa/src/qonnx/transformation/channels_last.py

-from hls4ml.model.layers import Concatenate, Input, Reshape
+from hls4ml.model.layers import Concatenate, Dense, Input, Reshape, Transpose
 from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.types import WeightVariable


 class ChannelsLastConverter(OptimizerPass):
@@ -133,3 +134,62 @@ def transform(self, model, node):

         node.channels_last_converted = True
         return True
+
+
+class RemoveTransposeBeforeFlatten(OptimizerPass):
+    '''After the channels last conversion, model may have a sequence: Transpose -> Flatten -> Dense.
+    In this case we can remove the expensive transpose and instead transpose the weights of the Dense layer.'''
+
+    def match(self, node):
+        if node.model.config.get_config_value('IOType') != 'io_parallel':
+            return False
+
+        if hasattr(node, '_channels_last_keep_transpose') and node._channels_last_keep_transpose:
+            return False
+
+        if isinstance(node, Reshape):
+            input_node = node.get_input_node()
+            output_nodes = node.get_output_nodes()
+            if (
+                len(node.get_attr('target_shape')) == 1
+                and isinstance(input_node, Transpose)
+                and len(output_nodes) == 1
+                and isinstance(output_nodes[0], Dense)
+            ):
+                return True
+
+        return False
+
+    def transform(self, model, node):
+        transpose_node = node.get_input_node()
+        dense_node = node.get_output_nodes()[0]
+        input_shape = transpose_node.get_output_variable().shape
+
+        if len(input_shape) == 2:  # Usually after Conv1D
+            tran_axis = [1, 0, 2]
+        elif len(input_shape) == 3:  # Usually after Conv2D
+            tran_axis = [1, 2, 0, 3]
+        else:  # In this case we bail
+            node._channels_last_keep_transpose = True
+            return False
+
+        weight_var = dense_node.get_weights('weight')
+        # Transpose the weights to achieve the same computation with transposed input
+        weight_data_t = weight_var.data.reshape(*input_shape, -1).transpose(*tran_axis)
+        weight_data_t = weight_data_t.reshape(-1, weight_data_t.shape[-1])
+        new_weight_var = WeightVariable(
+            var_name=weight_var.name,
+            type_name=weight_var.type.name,
+            precision=weight_var.type.precision,
+            quantizer=weight_var.quantizer,
+            data=weight_data_t,
+            index=dense_node.index,
+        )
+
+        # Update the weight variable of the node
+        dense_node.set_attr('weight', new_weight_var)
+
+        # Get rid of the Transpose node
+        model.remove_node(transpose_node)
+
+        return True
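
Note: a minimal NumPy sketch (standalone, not hls4ml code) of why permuting the Dense weights preserves the result, assuming the removed Transpose converted the channels-last activation back to channels-first order before flattening. Shapes and names below are arbitrary; weights_t mirrors the weight_data_t computation above for the Conv2D case (tran_axis = [1, 2, 0, 3]):

import numpy as np

rng = np.random.default_rng(0)
C, H, W, N = 3, 4, 5, 6                    # channels, spatial dims, Dense output size

x_cl = rng.normal(size=(H, W, C))          # channels-last activation feeding the Transpose
weights = rng.normal(size=(C * H * W, N))  # Dense weights expecting channels-first flatten order

# Original graph: Transpose back to channels-first, then Flatten -> Dense
y_ref = x_cl.transpose(2, 0, 1).reshape(-1) @ weights

# Optimized graph: no Transpose; the weight rows are permuted instead
weights_t = weights.reshape(C, H, W, N).transpose(1, 2, 0, 3).reshape(-1, N)
y_opt = x_cl.reshape(-1) @ weights_t

np.testing.assert_allclose(y_ref, y_opt)   # identical results
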
Lines changed: 33 additions & 7 deletions
@@ -1,21 +1,47 @@
-from hls4ml.model.layers import Transpose
+from hls4ml.model.layers import Input, Transpose
 from hls4ml.model.optimizer import OptimizerPass


-class RemoveUselessTranspose(OptimizerPass):
+class RemoveNopTranspose(OptimizerPass):
+    """
+    Remove a transpose layer if it doesn't do anything to a 1D array. i.e, 1D input and perm = [0]
+    """
+
     def match(self, node):
         is_match = isinstance(node, Transpose) and node.get_attr('perm') == [0]  # Useless transpose
         return is_match

     def transform(self, model, node):
-        """
-        Remove a transpose layer if it doesn't do anything. i.e 1D input and perm = [0]
-        """
-        print(f"Unnessary {node.name} in the model, optimizing ...")
+        print(f'Unnecessary transpose node ({node.name}) detected, optimizing ...')
         if not node.get_output_nodes():
-            print(f"WARNING: {node.name} is the output layer! No rewiring performed.")
+            print(f'WARNING: {node.name} is the output layer! No rewiring performed.')
             model.remove_node(node, rewire=False)  # Don't rewire if there is no output layer
         else:
             model.remove_node(node, rewire=True)

         return True
+
+
+class RemoveSingleChannelTranspose(OptimizerPass):
+    """
+    Remove transpose of inputs if the number of channels is 1 as for io_parallel this doesn't affect the array
+    representation used
+    """
+
+    def match(self, node):
+        if node.model.config.get_config_value('IOType') != 'io_parallel':
+            return False
+
+        return (
+            isinstance(node, Transpose)
+            and isinstance(node.get_input_node(), Input)
+            and node.get_input_variable().shape[0] == 1
+        )
+
+    def transform(self, model, node):
+        # Adjust the input shape and remove the Transpose node
+        input_var = node.get_input_variable()
+        input_var.shape.append(input_var.shape.pop(0))
+        model.remove_node(node)
+
+        return True
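
Note: a quick NumPy check (standalone, not hls4ml code) of why RemoveSingleChannelTranspose is safe for io_parallel: when the leading channel dimension is 1, moving it to the end does not change the element order of the flat array representation, so only the recorded shape needs to be rotated, which is what input_var.shape.append(input_var.shape.pop(0)) does above:

import numpy as np

x = np.arange(12).reshape(1, 3, 4)  # single-channel, channels-first tensor (C=1, H=3, W=4)
x_t = x.transpose(1, 2, 0)          # channels-last layout (H, W, C)

# The flattened element order is identical, so the Transpose is a no-op for a
# flat io_parallel array; only the shape metadata changes.
assert np.array_equal(x.ravel(), x_t.ravel())
print(x.shape, '->', x_t.shape)     # (1, 3, 4) -> (3, 4, 1)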

test/pytest/test_pytorch_api.py

Lines changed: 69 additions & 0 deletions
@@ -740,3 +740,72 @@ def test_skipped_layers(backend, io_type):
     hls_prediction = hls_model.predict(hls_input).flatten()

     np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
+
+
+@pytest.mark.parametrize('backend', ['Vivado', 'Quartus'])
+@pytest.mark.parametrize('io_type', ['io_parallel'])  # Only io_parallel for now
+@pytest.mark.parametrize('tensor_rank', [2, 3])
+def test_remove_transpose(backend, io_type, tensor_rank):
+    class TestModel(nn.Module):
+        def __init__(self, tensor_rank):
+            super().__init__()
+            if tensor_rank == 2:
+                self.conv1 = nn.Conv1d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+            else:
+                self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, bias=False)
+                self.relu1 = nn.ReLU()
+                self.flatten = nn.Flatten()
+                self.fc1 = nn.Linear(in_features=4 * 6 * 6, out_features=5, bias=False)
+                self.relu2 = nn.ReLU()
+
+        def forward(self, x):
+            # In the hls4ml model, there should be a Transpose node on the input tensor before conv1
+            x = self.conv1(x)
+            x = self.relu1(x)
+            x = self.flatten(x)  # This should result in a Transpose node that we aim to remove
+            x = self.fc1(x)
+            x = self.relu2(x)
+            return x
+
+    model = TestModel(tensor_rank=tensor_rank)
+    if tensor_rank == 2:
+        input_shape = (1, 8)
+        input_tensor = torch.randn(10, 1, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 1)).detach().numpy())
+    else:
+        input_shape = (1, 8, 8)
+        input_tensor = torch.randn(10, 1, 8, 8)
+        hls_input = np.ascontiguousarray(torch.permute(input_tensor, (0, 2, 3, 1)).detach().numpy())
+
+    batch_input_shape = (None,) + input_shape
+    config = config_from_pytorch_model(
+        model,
+        default_precision='ap_fixed<32,16>',
+        inputs_channel_last=False,  # Crucial for testing if the first Transpose was removed
+        transpose_outputs=False,
+    )
+    output_dir = str(test_root_path / f'hls4mlprj_pytorch_api_transpose_nop_{tensor_rank}d_{backend}_{io_type}')
+    hls_model = convert_from_pytorch_model(
+        model,
+        batch_input_shape,
+        hls_config=config,
+        output_dir=output_dir,
+        io_type=io_type,
+        backend=backend,
+    )
+
+    hls_model.compile()
+
+    # Test optimizers removed the two Transpose layers
+    transpose_layers = [layer for layer in list(hls_model.get_layers()) if layer.class_name == 'Transpose']
+    assert len(transpose_layers) == 0
+
+    # Test predictions match
+    pytorch_prediction = model(input_tensor).detach().numpy().flatten()
+    hls_prediction = hls_model.predict(hls_input).flatten()
+
+    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=5e-2)
