@@ -159,14 +159,6 @@ def call(self, inputs, training=None):
159
159
data_format = "channels_last" )
160
160
else :
161
161
bias = 0
162
- # If loaded from a ckpt, bias_quantizer is the ckpt value
163
- # Else if the layer is called for the first time, in this case bias
164
- # quantizer is None and we need to calculate bias quantizer
165
- # type according to accumulator type
166
- if self .bias_quantizer_internal is not None :
167
- q_bias = self .bias_quantizer_internal (bias )
168
- else :
169
- q_bias = bias
170
162
171
163
# begin batchnorm
172
164
_ = self .batchnorm (qdense_outputs , training = bn_training )
@@ -205,7 +197,7 @@ def call(self, inputs, training=None):
205
197
inv *= gamma
206
198
207
199
# fold bias with bn stats
208
- folded_bias = inv * (q_bias - new_mean ) + beta
200
+ folded_bias = inv * (bias - new_mean ) + beta
209
201
210
202
elif self .folding_mode == "ema_stats_folding" :
211
203
# We always scale the weights with a correction factor to the long term
@@ -227,25 +219,31 @@ def call(self, inputs, training=None):
227
219
batch_inv *= gamma
228
220
folded_bias = tf_utils .smart_cond (
229
221
bn_training ,
230
- lambda : batch_inv * (q_bias - mean ) + beta ,
231
- lambda : mv_inv * (q_bias - moving_mean ) + beta )
222
+ lambda : batch_inv * (bias - mean ) + beta ,
223
+ lambda : mv_inv * (bias - moving_mean ) + beta )
232
224
# moving stats is always used to fold kernel in tflite; before bn freeze
233
225
# an additional correction factor will be applied to the conv2d output
234
226
# end batchnorm
235
227
inv = mv_inv
236
228
else :
237
229
assert ValueError
238
230
231
+ # wrap dense kernel with bn parameters
232
+ folded_kernel = inv * kernel
239
233
# quantize the folded kernel
240
234
if self .kernel_quantizer is not None :
241
- q_kernel = self .kernel_quantizer_internal (kernel )
235
+ q_folded_kernel = self .kernel_quantizer_internal (folded_kernel )
236
+ else :
237
+ q_folded_kernel = folded_kernel
238
+
239
+ #quantize the folded bias
240
+ if self .bias_quantizer_internal is not None :
241
+ q_folded_bias = self .bias_quantizer_internal (folded_bias )
242
242
else :
243
- q_kernel = kernel
244
- # wrap qdense kernel with bn parameters
245
- folded_kernel = inv * q_kernel
243
+ q_folded_bias = folded_bias
246
244
247
- applied_kernel = folded_kernel
248
- applied_bias = folded_bias
245
+ applied_kernel = q_folded_kernel
246
+ applied_bias = q_folded_bias
249
247
250
248
#calculate qdense output using the quantized folded kernel
251
249
folded_outputs = tf .keras .backend .dot (inputs , applied_kernel )
@@ -290,8 +288,9 @@ def get_quantization_config(self):
290
288
"kernel_quantizer" : str (self .kernel_quantizer_internal ),
291
289
"bias_quantizer" : str (self .bias_quantizer_internal ),
292
290
}
293
- def get_quantizers (self ):
294
- return self .quantizers
291
+
292
+ def get_quantizers (self ):
293
+ return self .quantizers
295
294
296
295
# def get_prunable_weights(self):
297
296
# return [self.kernel]
0 commit comments