
Commit 6a6599f

Merge pull request #1 from julesmuhizi/qdbn_post_fold_quant
included weight transpose
2 parents 930a1e8 + 2107fd9 · commit 6a6599f

File tree

1 file changed: +18 −19 lines


qkeras/qdense_batchnorm.py

Lines changed: 18 additions & 19 deletions
@@ -159,14 +159,6 @@ def call(self, inputs, training=None):
           data_format="channels_last")
     else:
       bias = 0
-    # If loaded from a ckpt, bias_quantizer is the ckpt value
-    # Else if the layer is called for the first time, in this case bias
-    # quantizer is None and we need to calculate bias quantizer
-    # type according to accumulator type
-    if self.bias_quantizer_internal is not None:
-      q_bias = self.bias_quantizer_internal(bias)
-    else:
-      q_bias = bias
 
     # begin batchnorm
     _ = self.batchnorm(qdense_outputs, training=bn_training)
@@ -205,7 +197,7 @@ def call(self, inputs, training=None):
       inv *= gamma
 
       # fold bias with bn stats
-      folded_bias = inv * (q_bias - new_mean) + beta
+      folded_bias = inv * (bias - new_mean) + beta
 
     elif self.folding_mode == "ema_stats_folding":
       # We always scale the weights with a correction factor to the long term
@@ -227,25 +219,31 @@ def call(self, inputs, training=None):
       batch_inv *= gamma
       folded_bias = tf_utils.smart_cond(
           bn_training,
-          lambda: batch_inv * (q_bias - mean) + beta,
-          lambda: mv_inv * (q_bias - moving_mean) + beta)
+          lambda: batch_inv * (bias - mean) + beta,
+          lambda: mv_inv * (bias - moving_mean) + beta)
       # moving stats is always used to fold kernel in tflite; before bn freeze
       # an additional correction factor will be applied to the conv2d output
       # end batchnorm
       inv = mv_inv
     else:
       assert ValueError
 
+    # wrap dense kernel with bn parameters
+    folded_kernel = inv*kernel
     # quantize the folded kernel
     if self.kernel_quantizer is not None:
-      q_kernel = self.kernel_quantizer_internal(kernel)
+      q_folded_kernel = self.kernel_quantizer_internal(folded_kernel)
+    else:
+      q_folded_kernel = folded_kernel
+
+    #quantize the folded bias
+    if self.bias_quantizer_internal is not None:
+      q_folded_bias = self.bias_quantizer_internal(folded_bias)
     else:
-      q_kernel = kernel
-    # wrap qdense kernel with bn parameters
-    folded_kernel = inv * q_kernel
+      q_folded_bias = folded_bias
 
-    applied_kernel = folded_kernel
-    applied_bias = folded_bias
+    applied_kernel = q_folded_kernel
+    applied_bias = q_folded_bias
 
     #calculate qdense output using the quantized folded kernel
     folded_outputs = tf.keras.backend.dot(inputs, applied_kernel)
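The substance of this hunk is an ordering change: the batchnorm parameters are now folded into the dense kernel and bias first, and the quantizers are applied to the folded values, rather than quantizing the raw kernel and bias before folding. A minimal NumPy sketch of that post-fold quantization order (not the QKeras implementation; the fold_then_quantize helper and the quantizer callables are hypothetical stand-ins for the layer's internal quantizers):

    # Sketch of post-fold quantization: fold the batchnorm parameters into the
    # dense weights first, then quantize the folded values.
    import numpy as np

    def fold_then_quantize(kernel, bias, gamma, beta, mean, var, eps,
                           kernel_quantizer=None, bias_quantizer=None):
      """Fold BN statistics into the dense kernel/bias, then quantize the result."""
      inv = gamma / np.sqrt(var + eps)          # per-output-unit BN scale
      folded_kernel = inv * kernel              # wrap dense kernel with bn parameters
      folded_bias = inv * (bias - mean) + beta  # fold bias with bn stats
      q_kernel = kernel_quantizer(folded_kernel) if kernel_quantizer else folded_kernel
      q_bias = bias_quantizer(folded_bias) if bias_quantizer else folded_bias
      return q_kernel, q_bias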
@@ -290,8 +288,9 @@ def get_quantization_config(self):
         "kernel_quantizer": str(self.kernel_quantizer_internal),
         "bias_quantizer": str(self.bias_quantizer_internal),
     }
-  def get_quantizers(self):
-    return self.quantizers
+
+  def get_quantizers(self):
+    return self.quantizers
 
   # def get_prunable_weights(self):
   #   return [self.kernel]
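The fold formulas touched by this commit are the standard inference-time batchnorm fold. An illustrative NumPy check (not from this repo) that folding reproduces dense followed by batchnorm, i.e. bn(x @ W + b) == x @ (inv * W) + (inv * (b - mean) + beta) with inv = gamma / sqrt(var + eps):

    # Illustrative numerical check of the batchnorm fold used above.
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 8))
    W = rng.normal(size=(8, 3))
    b = rng.normal(size=3)
    gamma, beta = rng.normal(size=3), rng.normal(size=3)
    mean, var, eps = rng.normal(size=3), rng.uniform(0.5, 2.0, size=3), 1e-3

    inv = gamma / np.sqrt(var + eps)
    folded_outputs = x @ (inv * W) + (inv * (b - mean) + beta)
    reference = gamma * ((x @ W + b) - mean) / np.sqrt(var + eps) + beta
    assert np.allclose(folded_outputs, reference)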
