diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py index f836f36ac..a097d4ede 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py @@ -90,7 +90,8 @@ def _update_mask(self, weights): """ sparsity = self._pruning_schedule(self._step_fn())[1] with tf.name_scope('pruning_ops'): - abs_weights = tf.math.abs(weights) + # abs_weights = tf.math.abs(weights) + abs_weights = tf.random.uniform(weights.shape, dtype=weights.dtype) k = tf.dtypes.cast( tf.math.maximum( tf.math.round(