From 6160696ed95b57d26fe4361181c9aa3607d291f9 Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Wed, 22 Jun 2022 16:46:28 -0700 Subject: [PATCH 1/6] WIP relu merge. Still need to make changes to --- .../vivado/passes/convolution_templates.py | 2 + .../backends/vivado/passes/core_templates.py | 2 + hls4ml/model/optimizer/passes/relu_merge.py | 48 ++++ .../vivado/nnet_utils/nnet_dense_resource.h | 265 +++++++++++++++++- 4 files changed, 310 insertions(+), 7 deletions(-) create mode 100644 hls4ml/model/optimizer/passes/relu_merge.py diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index d4ac2d5b0a..48616d8cd2 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -10,9 +10,11 @@ static const unsigned n_out = {n_out}; static const unsigned reuse_factor = {reuse}; static const unsigned strategy = nnet::{strategy}; + static const bool merged_relu = {merged_relu}; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; + typedef {out_t}:: value_type out_t; template using product = nnet::product::{product_type}; }};\n""" diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index f63c0f454d..4aea84cd4e 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -13,11 +13,13 @@ static const unsigned reuse_factor = {reuse}; static const unsigned n_zeros = {nzeros}; static const unsigned n_nonzeros = {nonzeros}; + static const bool merged_relu = {merged_relu}; static const bool store_weights_in_bram = false; typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; + typedef {out_t}:: value_type out_t; template using product = nnet::product::{product_type}; }};\n""" diff --git a/hls4ml/model/optimizer/passes/relu_merge.py b/hls4ml/model/optimizer/passes/relu_merge.py new file mode 100644 index 0000000000..9c98eaa714 --- /dev/null +++ b/hls4ml/model/optimizer/passes/relu_merge.py @@ -0,0 +1,48 @@ +from hls4ml.model.optimizer import OptimizerPass + +class MergeRelu(OptimizerPass): + def match(self, node): + supported_layers = ['Conv2D', 'Conv2DBatchnorm', 'Dense'] + is_match = node.get_input_node().__class__.__name__ in supported_layers + + # hls4ml names ReLU activations 'Activation' + is_match = is_match and (node.__class__.__name__ == 'Activation') + return is_match + + def transform(self, model, node): + # Merge ReLU and Convolution/Dense layer + previous_node = node.get_input_node() + previous_node.index = node.index + previous_node.set_merged_relu(True) # Turn on merged_relu flag for this Conv/Dense layer + if 'Conv2D' in previous_node.__class__.__name__: + if previous_node.get_attr('data_format') == 'channels_last': + shape = [previous_node.attributes['out_height'], previous_node.attributes['out_width'], previous_node.attributes['n_filt']] + dims = ['OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index), 'N_FILT_{}'.format(previous_node.index)] + else: + shape = [previous_node.attributes['n_filt'], previous_node.attributes['out_height'], previous_node.attributes['out_width']] + dims = ['N_FILT_{}'.format(previous_node.index), 'OUT_HEIGHT_{}'.format(previous_node.index), 'OUT_WIDTH_{}'.format(previous_node.index)] + activation_precision, _ = model.config.get_precision(node, var='result') + 
previous_node.add_output_variable(shape, dims, precision=activation_precision) + if not node.get_output_nodes(): + print("WARNING: {} is the output layer! No rewiring performed.".format(node.name)) + model.remove_node(node, rewire=False) + else: + model.remove_node(node, rewire=True) + return True + elif 'Dense' in previous_node.__class__.__name__: + shape = previous_node.get_input_variable().shape[:] + shape[-1] = previous_node.attributes['n_out'] + if len(shape) > 1: + dims = ['N_LAYER_{}_{}'.format(i, previous_node.index) for i in range(1, len(shape) + 1)] + else: + dims = ['N_LAYER_{}'.format(previous_node.index)] + print('shape: {}'.format(shape)) + print('dims: {}'.format(dims)) + activation_precision, _ = model.config.get_precision(node, var='result') + previous_node.add_output_variable(shape, dims, precision=activation_precision) + if not node.get_output_nodes(): + print("WARNING: {} is the output layer! No rewiring performed.".format(node.name)) + model.remove_node(node, rewire=False) + else: + model.remove_node(node, rewire=True) + return True \ No newline at end of file diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h index c0e5d17591..1150bf1cea 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h @@ -263,6 +263,247 @@ void dense_resource_rf_gt_nin( } } +// Dense (with ReLU) +template +void dense_relu_resource_rf_leq_nin( + data_T data[CONFIG_T::n_in], + res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in*CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = CONFIG_T::reuse_factor; + const int multfactor = MIN(CONFIG_T::n_in,CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int multscale = multiplier_limit/CONFIG_T::n_out; + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((multiplier_limit == block_factor) && "This function is correct only for RF <= N_IN"); + + #pragma HLS function_instantiate variable=weights,biases + //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete + + InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + #pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t) biases[iacc]; + } + + ReuseLoop: + for (int ir = 0; ir < rufactor; ir++) { + #pragma HLS PIPELINE II=1 rewind + + int w_index = ir; + int in_index = ir; + int out_index = 0; + int acc_step = 0; + + MultLoop: + for (int im = 0; im < block_factor; im++) { + #pragma HLS UNROLL + + acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + + // Increment w_index + w_index += rufactor; + // Increment in_index + in_index += rufactor; + if (in_index >= nin) { + in_index = ir; + } + // Increment out_index + if (acc_step + 1 >= multscale) { + acc_step = 0; + out_index++; + } else { + acc_step++; + } + } + } + + // Cast to "res_t" type + Result: + for 
(int ires = 0; ires < CONFIG_T::n_out; ires++) { + #pragma HLS UNROLL + typename CONFIG_T::out_t act = cast(acc[ires]); + if (act > 0) res[ires] = act; + else res[ires] = 0; + } +} + +template +void dense_relu_resource_rf_gt_nin_rem0( + data_T data[CONFIG_T::n_in], + res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in*CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = MIN(CONFIG_T::reuse_factor, CONFIG_T::n_in * CONFIG_T::n_out); + const int multfactor = MIN(CONFIG_T::n_in,CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int multscale = multiplier_limit/CONFIG_T::n_out; + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((rufactor > nin && rufactor % nin == 0) && "This function is correct only for RF > N_IN && RF % N_IN == 0"); + + #pragma HLS function_instantiate variable=weights,biases + //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete + + InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + #pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t) biases[iacc]; + } + + int w_index; + int in_index = 0; + int out_index; + int outstep = 0; + const int outscale = rufactor / nin; + + int outidx[rufactor]; + IndexLoop: + for (int ir = 0; ir < rufactor; ir++) { + outidx[ir] = outstep; + if ((ir + 1) % nin == 0) { + outstep++; + } + } + + ReuseLoop: + for (int ir = 0; ir < rufactor; ir++) { + #pragma HLS PIPELINE II=1 rewind + + w_index = ir; + out_index = outidx[ir]/*outstep*/; + + MultLoop: + for (int im = 0; im < block_factor; im++) { + #pragma HLS UNROLL + acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + + w_index += rufactor; + if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds + out_index += outscale; + } + + in_index++; + if (in_index >= nin) { + in_index = 0; + //outstep++; // This causes a huge increase in scheduling and RTL generation times, hence the above workaround. 
+ } + } + + // Cast to "res_t" type + Result: + for (int ires = 0; ires < CONFIG_T::n_out; ires++) { + #pragma HLS UNROLL + typename CONFIG_T::out_t act = cast(acc[ires]); + if (act > 0) res[ires] = act; + else res[ires] = 0; + } +} + +template +void dense_relu_resource_rf_gt_nin( + data_T data[CONFIG_T::n_in], + res_T res[CONFIG_T::n_out], + typename CONFIG_T::weight_t weights[CONFIG_T::n_in*CONFIG_T::n_out], + typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) { + + const int rufactor = CONFIG_T::reuse_factor; + const int multfactor = MIN(CONFIG_T::n_in,CONFIG_T::reuse_factor); + const int multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, multfactor); + const int block_factor = DIV_ROUNDUP(CONFIG_T::n_in*CONFIG_T::n_out, CONFIG_T::reuse_factor); + const int multscale = multiplier_limit/CONFIG_T::n_out; + const int nin = CONFIG_T::n_in; + const int nout = CONFIG_T::n_out; + + assert((multiplier_limit % nout == 0 || rufactor >= nin) && "The current Reuse Factor is not allowed"); + assert((rufactor > nin) && "This function is correct only for RF > N_IN"); + + #pragma HLS function_instantiate variable=weights,biases + //#pragma HLS RESOURCE variable=weights core=RAM_2P_BRAM Commenting out the deisgnation HLS seems to choose correctly + #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor + #pragma HLS ARRAY_PARTITION variable=biases complete + + typename CONFIG_T::accum_t acc[CONFIG_T::n_out]; + #pragma HLS ARRAY_PARTITION variable=acc complete + + InitAccum: + for (int iacc = 0; iacc < nout; iacc++) { + #pragma HLS UNROLL + acc[iacc] = (typename CONFIG_T::accum_t) biases[iacc]; + } + + ReuseLoop: + for (int ir = 0; ir < rufactor; ir++) { + #pragma HLS PIPELINE II=1 rewind + typename CONFIG_T::accum_t tmpmult[block_factor]; + #pragma HLS ARRAY_PARTITION variable=tmpmult complete + + MultLoop: + for (int im = 0; im < block_factor; im++) { + #pragma HLS UNROLL + int w_index = ir + rufactor * im; + int in_index = w_index % nin; + if (w_index >= CONFIG_T::n_in*CONFIG_T::n_out) continue; // check out of bounds + tmpmult[im] = CONFIG_T::template product::product(data[in_index], weights[w_index]); + } + + typename CONFIG_T::accum_t mult[multiplier_limit]; + #pragma HLS ARRAY_PARTITION variable=mult complete + + ResetMult: + for (int imult = 0; imult < multiplier_limit; imult++) { + #pragma HLS UNROLL + mult[imult] = 0; + } + + AccumLoop1: + for (int im = 0; im < block_factor; im++) { + #pragma HLS UNROLL + int w_index = ir + rufactor * im; + int out_index = w_index / multfactor; + if (out_index >= multiplier_limit) continue; // check out of bounds + mult[out_index] += tmpmult[im]; + } + + AccumLoop2: + for (int im = 0; im < multiplier_limit; im++) { + #pragma HLS UNROLL + //int out_index = im/multscale; // This is the general case + //acc[out_index] += mult[im]; + acc[im] += mult[im]; // If RF > N_IN then multiplier_limit == n_out + } + } + + // Cast to "res_t" type + Result: + for (int ires = 0; ires < CONFIG_T::n_out; ires++) { + #pragma HLS UNROLL + typename CONFIG_T::out_t act = cast(acc[ires]); + if (act > 0) res[ires] = act; + else res[ires] = 0; + } +} + + template void dense_resource( data_T data[CONFIG_T::n_in], @@ -272,13 +513,23 @@ void dense_resource( #pragma HLS INLINE region - if (CONFIG_T::reuse_factor <= CONFIG_T::n_in) { - dense_resource_rf_leq_nin(data, res, weights, biases); - } else if (CONFIG_T::reuse_factor % CONFIG_T::n_in == 0) { - dense_resource_rf_gt_nin_rem0(data, res, weights, biases); - } else { - dense_resource_rf_gt_nin(data, res, weights, 
biases); - } + if (CONFIG_T::merged_relu) { + if (CONFIG_T::reuse_factor <= CONFIG_T::n_in) { + dense_relu_resource_rf_leq_nin(data, res, weights, biases); + } else if (CONFIG_T::reuse_factor % CONFIG_T::n_in == 0) { + dense_relu_resource_rf_gt_nin_rem0(data, res, weights, biases); + } else { + dense_relu_resource_rf_gt_nin(data, res, weights, biases); + } + } else { + if (CONFIG_T::reuse_factor <= CONFIG_T::n_in) { + dense_resource_rf_leq_nin(data, res, weights, biases); + } else if (CONFIG_T::reuse_factor % CONFIG_T::n_in == 0) { + dense_resource_rf_gt_nin_rem0(data, res, weights, biases); + } else { + dense_resource_rf_gt_nin(data, res, weights, biases); + } + } } } From 6db0a46e652f23e6ccab6373583bbb605934763a Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Mon, 27 Jun 2022 14:05:55 -0700 Subject: [PATCH 2/6] add merged_relu params to conv and dense templates by retrieving them from their own Layer class --- .../backends/vivado/passes/convolution_templates.py | 2 ++ hls4ml/backends/vivado/passes/core_templates.py | 2 ++ hls4ml/model/layers.py | 11 +++++++++++ 3 files changed, 15 insertions(+) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 48616d8cd2..7a97a40942 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -141,6 +141,8 @@ def format(self, node): mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') mult_params['n_out'] = node.get_attr('n_filt') mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + mult_params['merged_relu'] = "true" if self.get_merged_relu() else "false" + mult_params['out_t'] = self.intermediate_op.type.name mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 4aea84cd4e..59d22b69da 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -38,6 +38,8 @@ def format(self, node): params['nzeros'] = node.get_weights('weight').nzeros params['nonzeros'] = node.get_weights('weight').nonzeros params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + params['merged_relu'] = "true" if self.get_merged_relu() else "false" + params['out_t'] = self.get_output_variable().type.name return self.template.format(**params) diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index f821d08e90..b9c55945ec 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -61,6 +61,8 @@ def __init__(self, model, name, attributes, inputs, outputs=None): accum_t = NamedType(*reversed(self.model.config.get_precision(self, 'accum'))) self.set_attr('accum_t', accum_t) + self.merged_relu = False + layer_config = self.model.config.get_layer_config(self) for config_key, config_value in layer_config.items(): if config_key in self.attributes: @@ -234,6 +236,7 @@ def _default_config_params(self): params.update(self.attributes) params['iotype'] = self.model.config.get_config_value('IOType') params['reuse'] = self.get_attr('reuse_factor') + params['merged_relu'] = "true" if self.get_merged_relu() else "false" return params @@ -243,6 +246,12 @@ def get_layer_precision(self): 
precision[data_type.name] = data_type return precision + def get_merged_relu(self): + return self.merged_relu + + def set_merged_relu(self, merged_relu): + self.merged_relu = merged_relu # Bool flag to set merged_relu + def get_numbers_cpp(self): numbers = '' for k, v in self.get_output_variable().get_shape(): @@ -300,6 +309,7 @@ def initialize(self): else: dims = ['N_LAYER_{}'.format(self.index)] self.add_output_variable(shape, dims) + self.intermediate_op = self.get_output_variable() self.add_weights(quantizer=self.get_attr('weight_quantizer'), compression=self.model.config.get_compression(self)) self.add_bias(quantizer=self.get_attr('bias_quantizer')) @@ -416,6 +426,7 @@ def initialize(self): shape = [self.attributes['n_filt'], self.attributes['out_height'], self.attributes['out_width']] dims = ['N_FILT_{}'.format(self.index), 'OUT_HEIGHT_{}'.format(self.index), 'OUT_WIDTH_{}'.format(self.index)] self.add_output_variable(shape, dims) + self.intermediate_op = self.get_output_variable() self.add_weights(quantizer=self.get_attr('weight_quantizer')) self.add_bias(quantizer=self.get_attr('bias_quantizer')) From 56495326ed7677e90c00e6a8345cbf46a637dba6 Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Mon, 27 Jun 2022 17:33:26 -0700 Subject: [PATCH 3/6] WIP merge_relu does not catch the right layer ordering pattern because the relu layer class is not simply called Activation anymore. Need to fix --- hls4ml/backends/vivado/passes/convolution_templates.py | 4 ++-- hls4ml/backends/vivado/passes/core_templates.py | 4 ++-- hls4ml/model/optimizer/__init__.py | 4 ++-- .../optimizer/passes/{relu_merge.py => merge_relu.py} | 6 +++++- hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h | 9 ++++++--- 5 files changed, 17 insertions(+), 10 deletions(-) rename hls4ml/model/optimizer/passes/{relu_merge.py => merge_relu.py} (89%) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index 7a97a40942..bb6a7d1298 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -141,8 +141,8 @@ def format(self, node): mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_height') * node.get_attr('filt_width') mult_params['n_out'] = node.get_attr('n_filt') mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) - mult_params['merged_relu'] = "true" if self.get_merged_relu() else "false" - mult_params['out_t'] = self.intermediate_op.type.name + mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false" + mult_params['out_t'] = node.intermediate_op.type.name mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 59d22b69da..7f557cb150 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -38,8 +38,8 @@ def format(self, node): params['nzeros'] = node.get_weights('weight').nzeros params['nonzeros'] = node.get_weights('weight').nonzeros params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) - params['merged_relu'] = "true" if self.get_merged_relu() else "false" - params['out_t'] = self.get_output_variable().type.name + params['merged_relu'] = "true" 
if node.get_merged_relu() else "false" + params['out_t'] = node.get_output_variable().type.name return self.template.format(**params) diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index 0978da2f43..32dc2e6c61 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -14,10 +14,10 @@ try: import qkeras register_flow('convert', ['fuse_bias_add', 'remove_useless_transpose', 'output_rounding_saturation_mode', 'qkeras_factorize_alpha', 'extract_ternary_threshold', 'fuse_consecutive_batch_normalization']) # TODO Maybe not all QKeras optmizers belong here? - register_flow('optimize', ['eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat'], requires=['convert']) + register_flow('optimize', ['eliminate_linear_activation', 'fuse_consecutive_batch_normalization', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat', 'merge_relu'], requires=['convert']) except: register_flow('convert', ['fuse_bias_add', 'remove_useless_transpose']) - register_flow('optimize', ['eliminate_linear_activation', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat'], requires=['convert']) + register_flow('optimize', ['eliminate_linear_activation', 'fuse_batch_normalization', 'replace_multidimensional_dense_with_conv', 'set_precision_concat', 'merge_relu'], requires=['convert']) del opt_path del module_path diff --git a/hls4ml/model/optimizer/passes/relu_merge.py b/hls4ml/model/optimizer/passes/merge_relu.py similarity index 89% rename from hls4ml/model/optimizer/passes/relu_merge.py rename to hls4ml/model/optimizer/passes/merge_relu.py index 9c98eaa714..b0ebdf7640 100644 --- a/hls4ml/model/optimizer/passes/relu_merge.py +++ b/hls4ml/model/optimizer/passes/merge_relu.py @@ -5,8 +5,12 @@ def match(self, node): supported_layers = ['Conv2D', 'Conv2DBatchnorm', 'Dense'] is_match = node.get_input_node().__class__.__name__ in supported_layers - # hls4ml names ReLU activations 'Activation' + # hls4ml names ReLU activations 'Activation' TODO: Node class name isn't + # Activation anymore.. it can change and in our test case is called + # VivadoAcceleratorActivation is_match = is_match and (node.__class__.__name__ == 'Activation') + print(f"Node class name = {node.__class__.__name__}") + print(f"Does layer {node.__class__.__name__} match Relu merge? 
{is_match}") return is_match def transform(self, model, node): diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h index 1150bf1cea..c494dc82bd 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_dense_resource.h @@ -264,6 +264,7 @@ void dense_resource_rf_gt_nin( } // Dense (with ReLU) + template void dense_relu_resource_rf_leq_nin( data_T data[CONFIG_T::n_in], @@ -309,7 +310,8 @@ void dense_relu_resource_rf_leq_nin( for (int im = 0; im < block_factor; im++) { #pragma HLS UNROLL - acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); // Increment w_index w_index += rufactor; @@ -395,7 +397,8 @@ void dense_relu_resource_rf_gt_nin_rem0( MultLoop: for (int im = 0; im < block_factor; im++) { #pragma HLS UNROLL - acc[out_index] += CONFIG_T::template product::product(data[in_index], weights[w_index]); + acc[out_index] += static_cast( + CONFIG_T::template product::product(data[in_index], weights[w_index])); w_index += rufactor; if (w_index >= CONFIG_T::n_in * CONFIG_T::n_out) break; // check out of bounds @@ -463,7 +466,7 @@ void dense_relu_resource_rf_gt_nin( int w_index = ir + rufactor * im; int in_index = w_index % nin; if (w_index >= CONFIG_T::n_in*CONFIG_T::n_out) continue; // check out of bounds - tmpmult[im] = CONFIG_T::template product::product(data[in_index], weights[w_index]); + tmpmult[im] = CONFIG_T::template product::product(data[in_index], weights[w_index]); } typename CONFIG_T::accum_t mult[multiplier_limit]; From 6f59e18ddcf2e5bcd714f7723182481b024cc251 Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Tue, 28 Jun 2022 16:52:36 -0700 Subject: [PATCH 4/6] Match supported merge relu layers by checking if it's a subclass. Fix layer index error --- hls4ml/model/optimizer/passes/merge_relu.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/hls4ml/model/optimizer/passes/merge_relu.py b/hls4ml/model/optimizer/passes/merge_relu.py index b0ebdf7640..8121af53f6 100644 --- a/hls4ml/model/optimizer/passes/merge_relu.py +++ b/hls4ml/model/optimizer/passes/merge_relu.py @@ -1,22 +1,18 @@ from hls4ml.model.optimizer import OptimizerPass +from hls4ml.model.layers import Activation, Dense, Conv2D, Conv2DBatchnorm class MergeRelu(OptimizerPass): def match(self, node): - supported_layers = ['Conv2D', 'Conv2DBatchnorm', 'Dense'] - is_match = node.get_input_node().__class__.__name__ in supported_layers + supported_layers = (Dense, Conv2D, Conv2DBatchnorm) - # hls4ml names ReLU activations 'Activation' TODO: Node class name isn't - # Activation anymore.. it can change and in our test case is called - # VivadoAcceleratorActivation - is_match = is_match and (node.__class__.__name__ == 'Activation') - print(f"Node class name = {node.__class__.__name__}") - print(f"Does layer {node.__class__.__name__} match Relu merge? 
{is_match}") + is_match = issubclass(node.get_input_node().__class__, supported_layers) + # ReLU layers are of class Activation + is_match = is_match and issubclass(node.__class__, Activation) return is_match def transform(self, model, node): # Merge ReLU and Convolution/Dense layer previous_node = node.get_input_node() - previous_node.index = node.index previous_node.set_merged_relu(True) # Turn on merged_relu flag for this Conv/Dense layer if 'Conv2D' in previous_node.__class__.__name__: if previous_node.get_attr('data_format') == 'channels_last': From 361553ecbadd4ef7b69f22c6f5450df4441dacdc Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Thu, 30 Jun 2022 16:51:06 -0700 Subject: [PATCH 5/6] attempt to further restrict the matching function for the relu merge pass. --- hls4ml/model/optimizer/passes/merge_relu.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/hls4ml/model/optimizer/passes/merge_relu.py b/hls4ml/model/optimizer/passes/merge_relu.py index 8121af53f6..c483692a6f 100644 --- a/hls4ml/model/optimizer/passes/merge_relu.py +++ b/hls4ml/model/optimizer/passes/merge_relu.py @@ -3,11 +3,21 @@ class MergeRelu(OptimizerPass): def match(self, node): - supported_layers = (Dense, Conv2D, Conv2DBatchnorm) + backends = ['VivadoAccelerator', 'Vivado'] + supported_layers = ['Dense', 'Conv2D', 'Conv2DBatchNorm'] + # By the time this optimization pass runs, the Layer nodes' class names + # have been prepended with the name of the backend, e.g., a Conv2D + # layer is renamed VivadoAcceleratorConv2D. Thus, we strip the backend + # name for more streamlined matching. + input_node_class = node.get_input_node().__class__.__name__ + curr_node_class = node.__class__.__name__ + for b in backends: + input_node_class = input_node_class.replace(b, '') + curr_node_class = curr_node_class.replace(b, '') - is_match = issubclass(node.get_input_node().__class__, supported_layers) + is_match = input_node_class in supported_layers # ReLU layers are of class Activation - is_match = is_match and issubclass(node.__class__, Activation) + is_match = is_match and (curr_node_class == 'Activation') return is_match def transform(self, model, node): From 347a6bd4c3a42078f248d5a83de6a403c5aa10ee Mon Sep 17 00:00:00 2001 From: Olivia Weng Date: Mon, 4 Jul 2022 16:02:10 -0700 Subject: [PATCH 6/6] WIP trying to resolve out_t issues with the mult configs --- .../backends/vivado/passes/convolution_templates.py | 13 +++++++++++-- hls4ml/backends/vivado/passes/core_templates.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/hls4ml/backends/vivado/passes/convolution_templates.py b/hls4ml/backends/vivado/passes/convolution_templates.py index bb6a7d1298..564097d75b 100644 --- a/hls4ml/backends/vivado/passes/convolution_templates.py +++ b/hls4ml/backends/vivado/passes/convolution_templates.py @@ -14,7 +14,7 @@ typedef {accum_t.name} accum_t; typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; - typedef {out_t}:: value_type out_t; + typedef {out_t} out_t; template using product = nnet::product::{product_type}; }};\n""" @@ -68,6 +68,8 @@ def format(self, node): mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') mult_params['n_out'] = node.get_attr('n_filt') mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) + mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false" + mult_params['out_t'] = 
node.get_output_variable().type.name mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config @@ -142,7 +144,14 @@ def format(self, node): mult_params['n_out'] = node.get_attr('n_filt') mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false" - mult_params['out_t'] = node.intermediate_op.type.name + print(f"My out_t Class = {type(node.intermediate_op.type)}") + # TODO: Need to figure out when to append ::value_type (when + # node.intermediate_op's type is nnet::array but how to get that from a + # layer class?) and when not to Try: I think only io_stream IOType uses + # PackedType (io_parallel does not). Could grab IOType from layer + # class?? Turns out this isn't all that's needed--unclear what else. + # Also might need to add relu merge into dense_latency.h + mult_params['out_t'] = node.intermediate_op.type.name + '::value_type' if node.model.config.get_config_value('IOType') == 'io_stream' else node.intermediate_op.type.name mult_config = self.mult_template.format(**mult_params) return mult_config + '\n' + conv_config diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py index 7f557cb150..b8f7e4c89f 100644 --- a/hls4ml/backends/vivado/passes/core_templates.py +++ b/hls4ml/backends/vivado/passes/core_templates.py @@ -19,7 +19,7 @@ typedef {bias_t.name} bias_t; typedef {weight_t.name} weight_t; typedef {index_t.name} index_t; - typedef {out_t}:: value_type out_t; + typedef {out_t} out_t; template using product = nnet::product::{product_type}; }};\n"""
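
As a usage sketch to accompany the series (not part of any patch above): the snippet below shows how the merge_relu flow could be exercised end to end on a small Keras model. It assumes hls4ml's public conversion API (hls4ml.utils.config_from_keras_model and hls4ml.converters.convert_from_keras_model) behaves as in contemporary releases; the model, layer names, and output directory are illustrative only, and the final loop merely inspects the merged_relu flag these patches introduce rather than being a test shipped with the PR. Depending on the hls4ml version, the registered optimizer flows run at conversion time or once hls_model.compile() is called.

# Hypothetical end-to-end sketch: build a tiny Dense -> ReLU -> Dense model,
# convert it, and check whether the standalone ReLU was absorbed into the
# preceding Dense layer by the merge_relu optimizer pass.
from tensorflow.keras.layers import Activation, Dense
from tensorflow.keras.models import Sequential

import hls4ml

model = Sequential([
    Dense(16, input_shape=(8,), name='fc1'),
    Activation('relu', name='fc1_relu'),  # the Dense -> ReLU pattern MergeRelu targets
    Dense(5, name='output'),
])

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    output_dir='hls_prj_merge_relu',  # illustrative path
    backend='Vivado',
    io_type='io_parallel',
)
hls_model.compile()  # make sure the registered optimizer flows have been applied

# After merging, the standalone ReLU node should be gone and the Dense layer
# should carry merged_relu=True, which the backend templates emit as
# 'static const bool merged_relu = true;' in the layer's mult config.
for layer in hls_model.get_layers():
    print(layer.name, layer.__class__.__name__, getattr(layer, 'merged_relu', None))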