From 8107142816d3ba02d60fb37897d14f39ae0948fa Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 21 Jan 2024 18:34:30 -0500 Subject: [PATCH 01/30] feat: upgrade optimizer and agc --- gradient_accumulator/accumulators.py | 596 ++++++++++++++++----------- gradient_accumulator/agc.py | 115 +++--- setup.py | 4 +- 3 files changed, 407 insertions(+), 308 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 7baad13..b08d8c3 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -1,4 +1,5 @@ import tensorflow as tf +from typing import Optional from . import agc @@ -192,131 +193,170 @@ def reinit_grad_accum(self): ] +def get_gradients(gradients: list): + return [ + gradient.read_value() + for gradient in gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) + ] + # Implementation was derived from: # https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa # https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa @tf.keras.utils.register_keras_serializable("gradient-accumulator") -class GradientAccumulateOptimizer(opt): +class GradientAccumulateOptimizer(tf.keras.optimizers.Optimizer): """Optimizer wrapper for gradient accumulation.""" def __init__( self, - optimizer="SGD", - accum_steps=1, + optimizer: str = "SGD", + accum_steps: int = 1, reduction: str = "MEAN", + agc: bool = False, + mixed_precision: bool = False, name: str = "GradientAccumulateOptimizer", - **kwargs + dtype: tf.dtypes.DType = tf.float32, + **kwargs, ): - """Construct a new GradientAccumulateOptimizer optimizer. - - Adding support for sparse tensors was tricky, but this resource was - helpful. Note that you need to implement both _resource_apply_sparse() - and _resource_apply_sparse_duplicate_indices() for it to work as - intended. - - See here for more information regarding implementation: - * https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa + """ + Construct a new GradientAccumulateOptimizer optimizer. + + Parameters + ---------- + optimizer : str or tf.keras.optimizers.Optimizer + Optimizer that will be used to compute and apply gradients. + accum_steps : int, optional + Update gradient in every accumulation steps, must be > 0. + reduction : str, optional + Gradient reduction method to use. Can be 'MEAN' or 'SUM'. + agc : bool, optional + Whether to use adaptive gradient clipping. Defaults to False. + mixed_precision : bool, optional + Whether to use mixed precision. Defaults to False. + name : str, optional + Name for the operations created when applying gradients. Defaults to "GradientAccumulateOptimizer". + **kwargs : dict + Additional keyword arguments. Allowed keys are: + - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. + - `lr`: Learning rate, included for backward compatibility. Use `learning_rate` instead. + + Notes + ----- + Adding support for sparse tensors was tricky. For correct implementation, both `_resource_apply_sparse()` + and `_resource_apply_sparse_duplicate_indices()` methods need to be implemented. + + References + ---------- + .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 - Args: - optimizer: str or `tf.keras.optimizers.Optimizer` that will be - used to compute and apply gradients. - accum_steps: int > 0. Update gradient in every accumulation steps. - reduction: str. Which gradient reduction method to use. Defaults - to 'SUM'. - name: Optional name for the operations created when applying - gradients. Defaults to "GradientAccumulateOptimizer". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, - `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by - norm; `clipvalue` is clip gradients by value, `decay` is - included for backward compatibility to allow time inverse - decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. """ - self._optimizer = tf.keras.optimizers.get(optimizer) - self._accum_steps = accum_steps - self._reduction = reduction - self._step = None super().__init__(name, **kwargs) + self._optimizer = ( + tf.keras.mixed_precision.LossScaleOptimizer( + tf.keras.optimizers.get(optimizer) + ) + if mixed_precision + else tf.keras.optimizers.get(optimizer) + ) + self.base_optimizer = ( + self._optimizer.inner_optimizer if mixed_precision else self._optimizer + ) + self.mixed_precision = mixed_precision + self._mixed_precision = tf.constant(mixed_precision, dtype=tf.bool) + self.accum_steps = accum_steps + self._accum_steps = tf.constant(accum_steps, dtype=tf.int64) + self.reduction = reduction + self._reduction = tf.constant(reduction, dtype=tf.string) + self._step = tf.Variable( + initial_value=1, + trainable=False, + dtype=tf.int64, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + self._weights.append(self._step) + self._zero = tf.constant(0, dtype=tf.int64) + self.dtype = dtype + self.agc = agc + self._agc = tf.constant(agc) + if agc: + if "clip_factor" in kwargs: + self.clip_factor = tf.constant( + kwargs.pop("clip_factor"), dtype=tf.float32 + ) + else: + self.clip_factor = tf.constant(0.01, dtype=tf.float32) + else: + self.clip_factor = tf.constant(0.0, dtype=tf.float32) - def _create_slots(self, var_list): - """Creates slots for optimizer gradients. + def get_slot(self, *args, **kwargs): + return self._optimizer.get_slot(*args, **kwargs) + + def add_slot(self, *args, **kwargs): + return self._optimizer.add_slot(*args, **kwargs) + + def _create_slots(self, var_list: list): + # create slots using the base optimizer + self.base_optimizer._create_slots(var_list=var_list) + + base_optimizer_slots = self.base_optimizer.get_slot_names() - Args: - List of trainable variables. - """ - self._optimizer._create_slots(var_list=var_list) for var in var_list: - self.add_slot(var, "ga") + for slot_name in base_optimizer_slots: + self.add_slot( + var, + slot_name, + initializer=tf.zeros_like(var), + ) + + # create slots for accumulated gradients + for var in var_list: + self.add_slot(var, "ga", initializer=tf.zeros_like(var)) self._gradients = [self.get_slot(var, "ga") for var in var_list] @property - def step(self): - """The number of training steps this Optimizer has run. - Initializes step variable if None. - - Returns: - Current number of optimizer steps. - """ - if self._step is None: - with self._distribution_strategy_scope(): - self._step = self.add_weight( - "iter", - shape=[], - initializer="ones", - dtype=tf.int64, - trainable=False, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - self._weights.append(self._step) + def step(self) -> tf.Variable: + """Returns the number of training steps this Optimizer has run.""" return self._step @step.setter - def step(self, variable): # pragma: no cover + def step(self, variable: tf.Variable): """Sets the step value.""" - if self._step is not None: - raise RuntimeError( - "Cannot set `step` to a new Variable after " - "the Optimizer weights have been created" - ) self._step = variable self._weights.append(self._step) @property - def gradients(self): # pragma: no cover - """The accumulated gradients on the current replica. - - Returns: - Current gradients in optimizer. - """ - if not self._gradients: - raise ValueError( - "The accumulator should be called first to initialize the" - "gradients" - ) - return list( - gradient.read_value() if gradient is not None else gradient - for gradient in self._gradients + def gradients(self) -> list: + """Returns the current accumulated gradients on the replica.""" + tf.debugging.assert_greater( + tf.size(self._gradients), + self._zero, + message="Gradients have not been computed yet. " + "If you're using GradientAccumulateOptimizer with " + "a custom training loop, please make sure to call " + "optimizer.apply_gradients() before accessing " + "optimizer.gradients.", ) - def apply_gradients(self, grads_and_vars, name=None, **kwargs): - """Updates weights using gradients. + empty_grad_tensor = tf.zeros([], dtype=self._gradient.dtype) + return get_gradients(self._gradients, empty_grad_tensor) - Args: - grads_and_vars: dict containing variables and corresponding - gradients. - name: name to set when applying gradients. - **kwargs: keyword arguments. - Return: - Updated weights. - """ + def apply_gradients( + self, grads_and_vars: dict, name: Optional[str] = None, **kwargs + ) -> tf.Operation: train_op = super().apply_gradients(grads_and_vars, name, **kwargs) with tf.control_dependencies([train_op]): with tf.control_dependencies( [ self._optimizer.iterations.assign_add( tf.cast( - tf.where(self.step % self._accum_steps == 0, 1, 0), + tf.equal( + tf.math.mod( + self.step, + self._accum_steps, + ), + self._zero, + ), tf.int64, ), read_value=False, @@ -325,234 +365,292 @@ def apply_gradients(self, grads_and_vars, name=None, **kwargs): ): return self.step.assign_add(1, read_value=False) - def _resource_apply_dense( - self, grad, var, apply_state=None - ): # pragma: no cover - """Performs gradient update on dense tensor. + @tf.function(experimental_relax_shapes=True) + def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): + return agc.adaptive_clip_grad([var], [grad], clip_factor=self.clip_factor)[0] - Args: - grad: current gradient. - var: current variable. - apply_state: whether to apply X. - Returns: + @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + def _parse_grad(self, accum_gradient: tf.Tensor, var: tf.Variable) -> tf.Tensor: + """Parses the accumulated gradient and returns the gradient to be applied.""" + + apply_condition = tf.fill( + tf.shape(accum_gradient), + tf.equal(tf.math.mod(self.step, self._accum_steps), self._zero), + ) + + def apply_agc(): + return self._apply_agc(accum_gradient, var) + + def return_grad(): + return accum_gradient + + return tf.where(apply_condition, tf.cond(self._agc, apply_agc, return_grad), tf.zeros_like(var, dtype=accum_gradient.dtype)) + + @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + def reset_accum_gradient(self, accum_gradient: tf.Tensor, should_reset: tf.Tensor): + return tf.where( + should_reset, + accum_gradient.assign(tf.zeros_like(accum_gradient)), + accum_gradient, + ) + + def _resource_apply_dense( + self, grad: tf.Tensor, var: tf.Variable, apply_state: Optional[str] = None + ): + """ + Performs gradient update on sparse tensor. + + Parameters + ---------- + grad : tensor + Current gradient. + var : tensor + Current variable. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor apply_op. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - accum_gradient.assign_add( - grad / self._accum_steps, - use_locking=self._use_locking, - read_value=False, - ) - def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), - ) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) - if "apply_state" in self._optimizer._dense_apply_args: - train_op = self._optimizer._resource_apply_dense( - grad, var, apply_state=apply_state - ) - else: - train_op = self.optimizer._resource_apply_dense(grad, var) + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + accum_gradient.assign_add( + scaled_grad, use_locking=self._use_locking, read_value=False + ) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, + def _apply(accum_gradient, var, apply_state): + train_op = self.base_optimizer._resource_apply_dense( + self._parse_grad(accum_gradient, var), + var, + apply_state=apply_state, ) - reset_op = accum_gradient.assign( - reset_val, - use_locking=self._use_locking, - read_value=False, + + should_reset = tf.equal( + tf.math.mod(self.step, self._accum_steps), self._zero ) + reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + return tf.group(train_op, reset_op) return _apply(accum_gradient, var, apply_state) def _resource_apply_sparse( - self, grad, var, indices, apply_state=None - ): # pragma: no cover + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ): """Performs gradient update on sparse tensor. - Args: - grad: current gradient. - var: current variable. - indices: relevant indices to be used for masking the sparse tensor - during update. - Returns: - apply_op. + Parameters + ---------- + grad : tensor + The current gradient. + var : tensor + The current variable. + indices : tensor + Relevant indices to be used for masking the sparse tensor during update. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor + The operation after applying the gradient update. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - grad /= tf.cast(self._accum_steps, dtype=grad.dtype) - self._resource_scatter_add(accum_gradient, indices, grad) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), + train_op = self._optimizer._resource_apply_sparse( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state, ) - if "apply_state" in self.optimizer._sparse_apply_args: - train_op = self.optimizer._resource_apply_sparse( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state, - ) - else: - train_op = self.optimizer._resource_apply_sparse( - accum_gradient.sparse_read(indices), var, indices - ) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, - ) - reset_op = accum_gradient.assign( - reset_val, - use_locking=self._use_locking, - read_value=False, + should_reset = tf.equal( + tf.math.mod(self.step, self._accum_steps), self._zero ) + reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + return tf.group(train_op, reset_op) return _apply(accum_gradient, var, apply_state) - # TODO: needs to be updated and tested def _resource_apply_sparse_duplicate_indices( - self, grad, var, indices, apply_state=None - ): # pragma: no cover - """Performs gradient update on sparse tensor. - - Args: - grad: current gradient. - var: current variable. - indices: relevant indices to be used for masking the sparse tensor - during update. - Returns: - apply_op. + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ): """ + Performs gradient update on sparse tensor with duplicate indices. + + Parameters + ---------- + grad : tf.Tensor + Current gradient. + var : tf.Variable + Current variable. + indices : tf.Tensor + Relevant indices to be used for masking the sparse tensor during update. + apply_state : str, optional + State of the optimizer. Defaults to None. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - grad /= tf.cast(self._accum_steps, dtype=grad.dtype) - self._resource_scatter_add(accum_gradient, indices, grad) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), + train_op = self._optimizer._resource_apply_sparse_duplicate_indices( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state, ) - if "apply_state" in self.optimizer._sparse_apply_args: - train_op = ( - self.optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state, - ) - ) - else: - train_op = ( - self.optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), var, indices - ) - ) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, - ) - reset_op = accum_gradient.assign( - reset_val, - use_locking=self._use_locking, - read_value=False, + # train operation must be executed before we can reset gradients + should_reset = tf.equal( + tf.math.mod(self.step, self._accum_steps), self._zero ) + reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + return tf.group(train_op, reset_op) return _apply(accum_gradient, var, apply_state) - def reset(self): # pragma: no cover - """Resets the accumulated gradients on the current replica.""" - assign_ops = [] - if not self._gradients: - return assign_ops - - for gradient in self._gradients: - if gradient is not None: - assign_ops.append( - gradient.assign( - tf.zeros_like(gradient), - use_locking=self._use_locking, - read_value=False, - ) - ) + @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + def _reset_single_gradient(self, gradient: tf.Tensor): + return gradient.assign( + tf.zeros_like(gradient), use_locking=self._use_locking, read_value=False + ) - return tf.group(assign_ops) + def reset(self): + """Resets the accumulated gradients on the current replica.""" + reset_ops = [ + self._reset_single_gradient(gradient) + for gradient in self._gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) + ] + return tf.group(*reset_ops) @property - def optimizer(self): - """The optimizer that this AccumOptimizer is wrapping.""" + def optimizer(self) -> tf.keras.optimizers.Optimizer: + """The optimizer that this AccumOptimizer is wrapping. In the case of mixed precision, this is the LossScaleOptimizer.""" return self._optimizer @property - def iterations(self): - """Returns current iteration value of optimizer. - - Returns: - iterations of optimizer.""" + def iterations(self) -> tf.Variable: + """Returns current iteration value of optimizer.""" return self._optimizer.iterations @iterations.setter - def iterations(self, variable): + def iterations(self, variable: tf.Variable): """Sets the iterations value of optimizer.""" self._optimizer.iterations = variable @property - def learning_rate(self): # pragma: no cover - """Returns the learning rate of the optimizer. + def lr(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.base_optimizer.learning_rate - Returns: - learning rate of optimizer. - """ - return self._optimizer._get_hyper("learning_rate") + @lr.setter + def lr(self, lr): + """Sets the learning rate of the optimizer.""" + self.base_optimizer.learning_rate = lr + self._learning_rate = lr - @learning_rate.setter - def learning_rate(self, learning_rate): # pragma: no cover - """Sets the learning rate of the optimizer. + @property + def learning_rate(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.lr - Args: - learning_rate: which learning rate to set in the optimizer. - """ - self._optimizer._set_hyper("learning_rate", learning_rate) + @learning_rate.setter + def learning_rate(self, learning_rate: float): + """Sets the learning rate of the optimizer.""" + self.lr = learning_rate - def get_config(self): - """Returns the configuration as dict.""" - config = { + @property + def _learning_rate(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.lr + + def get_config(self) -> dict: + """Returns the configuration as a dictionary.""" + config = super().get_config() + custom_config = { "optimizer": tf.keras.optimizers.serialize(self._optimizer), - "accum_steps": self._accum_steps, - "reduction": self._reduction, + "accum_steps": self.accum_steps, + "reduction": self.reduction, + "agc": self.agc, + "mixed_precision": self.mixed_precision, + "dtype": self.dtype.name, } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) + config.update(custom_config) + return config @classmethod - def from_config(cls, config, custom_objects=None): - """Gets config of original optimizer and deserializes it.""" + def from_config(cls, config: dict, custom_objects: Optional[str] = None) -> object: + """Creates an instance of the optimizer from its config.""" + optimizer_config = config.pop("optimizer") optimizer = tf.keras.optimizers.deserialize( - config.pop("optimizer"), custom_objects=custom_objects + optimizer_config, custom_objects=custom_objects ) - return cls(optimizer, **config) + return cls(optimizer=optimizer, **config) diff --git a/gradient_accumulator/agc.py b/gradient_accumulator/agc.py index 88a3d4f..fc583c8 100644 --- a/gradient_accumulator/agc.py +++ b/gradient_accumulator/agc.py @@ -2,80 +2,81 @@ # implementation from: https://github.com/sayakpaul/Adaptive-Gradient-Clipping/blob/main/agc.py # noqa +SCALAR = tf.constant([], dtype=tf.int32) +LINEAR = tf.constant([0], dtype=tf.int32) +TENSOR2D = tf.constant([0, 1], dtype=tf.int32) +TENSOR3D = tf.constant([0, 1, 2], dtype=tf.int32) +TENSOR4D = tf.constant([0, 1, 2, 3], dtype=tf.int32) + + +@tf.function def compute_norm(x, axis, keepdims): """ Computes the euclidean norm of a tensor :math:`x`. - - Args: - x: input tensor. - axis: which axis to compute norm across. - keepdims: whether to keep dimension after applying along axis. - - Returns: - Euclidean norm. """ - return tf.math.reduce_sum(x**2, axis=axis, keepdims=keepdims) ** 0.5 + return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims)) +@tf.function def unitwise_norm(x): """ - Wrapper class which dynamically sets `axis` and `keepdims` given an - input `x` for calculating euclidean norm. + Computes the unitwise norm of a tensor. + """ - Args: - x: input tensor. + def compute_reduction_axes(r): + axes = tf.case( + [ + ( + tf.equal(r, 1), + lambda: SCALAR, + ), + ( + tf.equal(r, 2), + lambda: LINEAR, + ), + ( + tf.equal(r, 3), + lambda: TENSOR2D, + ), + ( + tf.equal(r, 4), + lambda: TENSOR3D, + ), + ( + tf.equal(r, 5), + lambda: TENSOR4D, + ), + ], + default=lambda: SCALAR, + ) + return axes - Returns: - Euclidean norm. - """ - if len(x.get_shape()) <= 1: # Scalars and vectors - axis = None - keepdims = False - elif len(x.get_shape()) in [ - 2, - 3, - ]: # Linear layers of shape IO or multihead linear - axis = 0 - keepdims = True - elif len(x.get_shape()) == 4: # Conv kernels of shape HWIO - axis = [0, 1, 2] - keepdims = True - elif len(x.get_shape()) == 5: # Conv kernels of shape HWDIO - axis = [0, 1, 2, 3] - keepdims = True - else: - raise ValueError(f"Got a parameter with shape not in [1, 2, 4, 5]! {x}") - return compute_norm(x, axis, keepdims) + return compute_norm(x, axis=compute_reduction_axes(tf.rank(x)), keepdims=True) +@tf.function def adaptive_clip_grad( parameters, gradients, clip_factor: float = 0.01, eps: float = 1e-3 ): """ - Performs adaptive gradient clipping on a given set of parameters and - gradients. + Performs adaptive gradient clipping on a given set of parameters and gradients. + """ - * Official JAX implementation (paper authors): - https://github.com/deepmind/deepmind-research/tree/master/nfnets # noqa - * Ross Wightman's implementation - https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/agc.py # noqa + def clip_grad(param, grad): + max_norm = tf.math.multiply( + tf.math.maximum(unitwise_norm(param), eps), clip_factor + ) + grad_norm = unitwise_norm(grad) + adjusted_norm = tf.math.divide(max_norm, tf.math.maximum(grad_norm, 1e-6)) + new_grad = tf.where( + tf.math.less(grad_norm, max_norm), + grad, + tf.math.multiply(grad, adjusted_norm), + ) + return new_grad - Args: - parameters: Which parameters to apply method on. - gradients: Which gradients to apply clipping on. - clip_factor: Sets upper limit for gradient clipping. - eps: Epsilon - small number in :math:`max()` to avoid zero norm and - preserve numerical stability. + new_grads = tf.map_fn( + lambda x: clip_grad(x[0], x[1]), (parameters, gradients), dtype=tf.float32 + ) - Returns: - Updated gradients after gradient clipping. - """ - new_grads = [] - for (params, grads) in zip(parameters, gradients): - p_norm = unitwise_norm(params) - max_norm = tf.math.maximum(p_norm, eps) * clip_factor - grad_norm = unitwise_norm(grads) - clipped_grad = grads * (max_norm / tf.math.maximum(grad_norm, 1e-6)) - new_grad = tf.where(grad_norm < max_norm, grads, clipped_grad) - new_grads.append(new_grad) return new_grads diff --git a/setup.py b/setup.py index bb68799..7d9a6d8 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="gradient-accumulator", version="0.5.2", - author="André Pedersen and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo", + author="André Pedersen and Derek Alexander and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo", author_email="andrped94@gmail.com", description="Package for gradient accumulation in TensorFlow", long_description=long_description, @@ -14,7 +14,7 @@ url="https://github.com/andreped/GradientAccumulator", packages=setuptools.find_packages(exclude=('tests', 'notebooks', 'assets', 'docs', 'shell')), install_requires=[ - "tensorflow", + "tensorflow<=2.10.0", "numpy<=1.23.2", ], extras_require={"dev": [ From 7e31c642af93a4fe12cb28571e91b8138a764036 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 21 Jan 2024 18:44:41 -0500 Subject: [PATCH 02/30] feat: swap base class for --- gradient_accumulator/accumulators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index b08d8c3..26a8e07 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -204,7 +204,7 @@ def get_gradients(gradients: list): # https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa # https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa @tf.keras.utils.register_keras_serializable("gradient-accumulator") -class GradientAccumulateOptimizer(tf.keras.optimizers.Optimizer): +class GradientAccumulateOptimizer(opt): """Optimizer wrapper for gradient accumulation.""" def __init__( From 0cb355a741131d637e026371f59acc4da917b6a6 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Mon, 29 Jan 2024 23:51:25 -0500 Subject: [PATCH 03/30] fix: change to absolute imports for clarity and ease of testing --- gradient_accumulator/__init__.py | 9 +++------ gradient_accumulator/accumulators.py | 2 +- gradient_accumulator/utils.py | 2 +- setup.py | 7 ++++--- tests/test_adaptive_gradient_clipping.py | 6 ++---- tests/test_batch_norm.py | 6 +----- tests/test_bn_pretrained_swap.py | 10 +--------- tests/test_expected_result.py | 10 ++-------- tests/test_model_distribute.py | 2 +- tests/test_model_expected_result.py | 8 ++------ tests/test_multitask.py | 16 +++------------- tests/test_optimizer_distribute.py | 4 +--- tests/test_optimizer_invariance.py | 10 ++-------- tests/test_optimizer_wrapper.py | 7 +------ tests/test_param_count.py | 1 - tests/test_sparse_optimizer.py | 15 +++------------ tests/utils.py | 3 +-- 17 files changed, 29 insertions(+), 89 deletions(-) diff --git a/gradient_accumulator/__init__.py b/gradient_accumulator/__init__.py index 5fa462b..6a50a3f 100644 --- a/gradient_accumulator/__init__.py +++ b/gradient_accumulator/__init__.py @@ -1,6 +1,3 @@ -from .accumulators import GradientAccumulateModel -from .accumulators import GradientAccumulateOptimizer -from .agc import adaptive_clip_grad -from .agc import compute_norm -from .agc import unitwise_norm -from .layers import AccumBatchNormalization +from gradient_accumulator.accumulators import GradientAccumulateModel, GradientAccumulateOptimizer +from gradient_accumulator.agc import adaptive_clip_grad, compute_norm, unitwise_norm +from gradient_accumulator.layers import AccumBatchNormalization diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 26a8e07..42738a5 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -1,7 +1,7 @@ import tensorflow as tf from typing import Optional -from . import agc +from gradient_accumulator import agc # dynamically handle which Optimizer class to use dep on tf version opt = tf.keras.optimizers.Optimizer diff --git a/gradient_accumulator/utils.py b/gradient_accumulator/utils.py index 14785eb..fd845b3 100644 --- a/gradient_accumulator/utils.py +++ b/gradient_accumulator/utils.py @@ -1,6 +1,6 @@ import tensorflow as tf -from .layers import AccumBatchNormalization +from gradient_accumulator.layers import AccumBatchNormalization def replace_batchnorm_layers(model, accum_steps, position="replace"): diff --git a/setup.py b/setup.py index 7d9a6d8..0c8a3f6 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,9 @@ url="https://github.com/andreped/GradientAccumulator", packages=setuptools.find_packages(exclude=('tests', 'notebooks', 'assets', 'docs', 'shell')), install_requires=[ - "tensorflow<=2.10.0", + "protobuf<=3.19.6", + 'tensorflow<=2.10.0,>=2.0.0; platform_system!="Darwin"', + 'tensorflow-macos<=2.10.0,>=2.0.0; platform_system=="Darwin"', "numpy<=1.23.2", ], extras_require={"dev": [ @@ -24,8 +26,7 @@ "black==22.3.0", "isort==5.10.1", "flake8==4.0.1", - "tensorflow-datasets<=4.8.2", - "protobuf<=3.20", + "tensorflow-datasets<=4.8.0", ]}, classifiers=[ "Development Status :: 4 - Beta", diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index d68a8ea..81ef2a6 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -2,13 +2,11 @@ import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras import mixed_precision from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator import unitwise_norm +from gradient_accumulator import GradientAccumulateModel, unitwise_norm -from .utils import normalize_img +from tests.utils import normalize_img def test_unitwise_norm(): diff --git a/tests/test_batch_norm.py b/tests/test_batch_norm.py index 91dba0c..35a343e 100644 --- a/tests/test_batch_norm.py +++ b/tests/test_batch_norm.py @@ -1,6 +1,3 @@ -import os -import random as python_random - import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -9,8 +6,7 @@ from gradient_accumulator import GradientAccumulateModel from gradient_accumulator.layers import AccumBatchNormalization -from .utils import normalize_img -from .utils import reset +from tests.utils import normalize_img, reset def run_experiment( diff --git a/tests/test_bn_pretrained_swap.py b/tests/test_bn_pretrained_swap.py index c23fe36..7ac99bb 100644 --- a/tests/test_bn_pretrained_swap.py +++ b/tests/test_bn_pretrained_swap.py @@ -1,19 +1,11 @@ -import os -import random as python_random - -import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator.layers import AccumBatchNormalization from gradient_accumulator.utils import replace_batchnorm_layers -from .utils import gray2rgb -from .utils import normalize_img -from .utils import reset -from .utils import resizeImage +from tests.utils import gray2rgb, normalize_img, resizeImage def test_swap_layer( diff --git a/tests/test_expected_result.py b/tests/test_expected_result.py index 15a9404..b808373 100644 --- a/tests/test_expected_result.py +++ b/tests/test_expected_result.py @@ -1,17 +1,11 @@ -import os -import random as python_random - import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator import GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer -from .utils import get_opt -from .utils import normalize_img -from .utils import reset +from tests.utils import get_opt, normalize_img, reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_model_distribute.py b/tests/test_model_distribute.py index b47d75d..99bfa6e 100644 --- a/tests/test_model_distribute.py +++ b/tests/test_model_distribute.py @@ -4,7 +4,7 @@ from gradient_accumulator import GradientAccumulateModel -from .utils import get_opt +from tests.utils import get_opt def test_model_distribute(): diff --git a/tests/test_model_expected_result.py b/tests/test_model_expected_result.py index b0e19cb..01aaf8a 100644 --- a/tests/test_model_expected_result.py +++ b/tests/test_model_expected_result.py @@ -2,13 +2,9 @@ import tensorflow as tf from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator import GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer -from .utils import get_opt -from .utils import normalize_img -from .utils import reset -from .utils import run_experiment +from tests.utils import get_opt, normalize_img, reset, run_experiment # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_multitask.py b/tests/test_multitask.py index 311d3d1..2457148 100644 --- a/tests/test_multitask.py +++ b/tests/test_multitask.py @@ -1,23 +1,13 @@ -import os -import random as python_random - import numpy as np import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras.layers import Activation -from tensorflow.keras.layers import Conv2D -from tensorflow.keras.layers import Dense -from tensorflow.keras.layers import Flatten -from tensorflow.keras.layers import Input -from tensorflow.keras.layers import MaxPooling2D -from tensorflow.keras.layers import UpSampling2D -from tensorflow.keras.models import Model +from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D, UpSampling2D +from tensorflow.keras import Model from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel -from .utils import normalize_img -from .utils import reset +from tests.utils import normalize_img, reset def create_multi_input_output(image, label): diff --git a/tests/test_optimizer_distribute.py b/tests/test_optimizer_distribute.py index 1dd5441..6fe3cf5 100644 --- a/tests/test_optimizer_distribute.py +++ b/tests/test_optimizer_distribute.py @@ -5,9 +5,7 @@ from gradient_accumulator import GradientAccumulateOptimizer -from .utils import get_opt -from .utils import normalize_img -from .utils import reset +from tests.utils import get_opt, normalize_img, reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_optimizer_invariance.py b/tests/test_optimizer_invariance.py index cc990a6..bc93a99 100644 --- a/tests/test_optimizer_invariance.py +++ b/tests/test_optimizer_invariance.py @@ -1,17 +1,11 @@ -import os -import random as python_random - import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator import GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer -from .utils import get_opt -from .utils import normalize_img -from .utils import reset +from tests.utils import get_opt, normalize_img, reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_optimizer_wrapper.py b/tests/test_optimizer_wrapper.py index 22a538c..68ef3ab 100644 --- a/tests/test_optimizer_wrapper.py +++ b/tests/test_optimizer_wrapper.py @@ -1,6 +1,3 @@ -import os -import random as python_random - import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -8,9 +5,7 @@ from gradient_accumulator import GradientAccumulateOptimizer -from .utils import get_opt -from .utils import normalize_img -from .utils import reset +from tests.utils import get_opt, normalize_img, reset tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_param_count.py b/tests/test_param_count.py index 02ffd5d..463ec90 100644 --- a/tests/test_param_count.py +++ b/tests/test_param_count.py @@ -1,6 +1,5 @@ import tensorflow as tf from tensorflow.keras.layers import Dense -from tensorflow.keras.models import Sequential from gradient_accumulator import GradientAccumulateModel diff --git a/tests/test_sparse_optimizer.py b/tests/test_sparse_optimizer.py index 087abcb..22f7bde 100644 --- a/tests/test_sparse_optimizer.py +++ b/tests/test_sparse_optimizer.py @@ -1,20 +1,11 @@ -import os -import random as python_random - -import numpy as np import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras.layers import Dense -from tensorflow.keras.layers import Embedding -from tensorflow.keras.layers import Flatten -from tensorflow.keras.models import Sequential -from tensorflow.keras.models import load_model -from tensorflow.keras.preprocessing.sequence import pad_sequences -from tensorflow.keras.preprocessing.text import one_hot +from tensorflow.keras.layers import Dense, Embedding, Flatten +from tensorflow.keras.models import Sequential, load_model from gradient_accumulator import GradientAccumulateOptimizer -from .utils import reset +from tests.utils import reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/utils.py b/tests/utils.py index 3e1ece4..6cd5606 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,8 +6,7 @@ import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel -from gradient_accumulator import GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) From 40e1f9835cc878461b1f635bdea589c3d04ba22e Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 4 Feb 2024 21:37:40 -0500 Subject: [PATCH 04/30] fix: install, lint, black --- gradient_accumulator/__init__.py | 7 ++- gradient_accumulator/accumulators.py | 66 ++++++++++++++++++++-------- gradient_accumulator/agc.py | 16 ++++--- shell/format.sh | 0 shell/lint.sh | 0 tests/test_model_expected_result.py | 2 +- 6 files changed, 64 insertions(+), 27 deletions(-) mode change 100644 => 100755 shell/format.sh mode change 100644 => 100755 shell/lint.sh diff --git a/gradient_accumulator/__init__.py b/gradient_accumulator/__init__.py index 6a50a3f..a8d6100 100644 --- a/gradient_accumulator/__init__.py +++ b/gradient_accumulator/__init__.py @@ -1,3 +1,6 @@ -from gradient_accumulator.accumulators import GradientAccumulateModel, GradientAccumulateOptimizer -from gradient_accumulator.agc import adaptive_clip_grad, compute_norm, unitwise_norm +from gradient_accumulator.accumulators import GradientAccumulateModel +from gradient_accumulator.accumulators import GradientAccumulateOptimizer +from gradient_accumulator.agc import adaptive_clip_grad +from gradient_accumulator.agc import compute_norm +from gradient_accumulator.agc import unitwise_norm from gradient_accumulator.layers import AccumBatchNormalization diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 42738a5..54aff86 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -1,6 +1,7 @@ -import tensorflow as tf from typing import Optional +import tensorflow as tf + from gradient_accumulator import agc # dynamically handle which Optimizer class to use dep on tf version @@ -10,7 +11,7 @@ # https://stackoverflow.com/a/66524901 -# https://keras.io/guides/customizing_what_happens_in_fit/ +# https://keras.io/guides/customizing_what_happens_in_fit/ # noqa @tf.keras.utils.register_keras_serializable("gradient-accumulator") class GradientAccumulateModel(tf.keras.Model): """Model wrapper for gradient accumulation.""" @@ -24,7 +25,7 @@ def __init__( eps: float = 1e-3, experimental_distributed_support: bool = False, *args, - **kwargs + **kwargs, ): """Adds gradient accumulation support to existing Keras Model. @@ -200,6 +201,7 @@ def get_gradients(gradients: list): if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) ] + # Implementation was derived from: # https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa # https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa @@ -234,20 +236,23 @@ def __init__( mixed_precision : bool, optional Whether to use mixed precision. Defaults to False. name : str, optional - Name for the operations created when applying gradients. Defaults to "GradientAccumulateOptimizer". + Name for the operations created when applying gradients. Defaults to + "GradientAccumulateOptimizer". **kwargs : dict Additional keyword arguments. Allowed keys are: - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. - - `lr`: Learning rate, included for backward compatibility. Use `learning_rate` instead. + - `lr`: Learning rate, included for backward compatibility. Use + `learning_rate` instead. Notes ----- - Adding support for sparse tensors was tricky. For correct implementation, both `_resource_apply_sparse()` + Adding support for sparse tensors was tricky. For correct implementation, both + `_resource_apply_sparse()` and `_resource_apply_sparse_duplicate_indices()` methods need to be implemented. References ---------- - .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 + .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa """ super().__init__(name, **kwargs) @@ -259,7 +264,9 @@ def __init__( else tf.keras.optimizers.get(optimizer) ) self.base_optimizer = ( - self._optimizer.inner_optimizer if mixed_precision else self._optimizer + self._optimizer.inner_optimizer + if mixed_precision + else self._optimizer ) self.mixed_precision = mixed_precision self._mixed_precision = tf.constant(mixed_precision, dtype=tf.bool) @@ -367,11 +374,16 @@ def apply_gradients( @tf.function(experimental_relax_shapes=True) def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): - return agc.adaptive_clip_grad([var], [grad], clip_factor=self.clip_factor)[0] + return agc.adaptive_clip_grad( + [var], [grad], clip_factor=self.clip_factor + )[0] @tf.function(experimental_relax_shapes=True, reduce_retracing=True) - def _parse_grad(self, accum_gradient: tf.Tensor, var: tf.Variable) -> tf.Tensor: - """Parses the accumulated gradient and returns the gradient to be applied.""" + def _parse_grad( + self, accum_gradient: tf.Tensor, var: tf.Variable + ) -> tf.Tensor: + """Parses the accumulated gradient and returns the gradient to be + applied.""" apply_condition = tf.fill( tf.shape(accum_gradient), @@ -384,10 +396,16 @@ def apply_agc(): def return_grad(): return accum_gradient - return tf.where(apply_condition, tf.cond(self._agc, apply_agc, return_grad), tf.zeros_like(var, dtype=accum_gradient.dtype)) + return tf.where( + apply_condition, + tf.cond(self._agc, apply_agc, return_grad), + tf.zeros_like(var, dtype=accum_gradient.dtype), + ) @tf.function(experimental_relax_shapes=True, reduce_retracing=True) - def reset_accum_gradient(self, accum_gradient: tf.Tensor, should_reset: tf.Tensor): + def reset_accum_gradient( + self, accum_gradient: tf.Tensor, should_reset: tf.Tensor + ): return tf.where( should_reset, accum_gradient.assign(tf.zeros_like(accum_gradient)), @@ -395,7 +413,10 @@ def reset_accum_gradient(self, accum_gradient: tf.Tensor, should_reset: tf.Tenso ) def _resource_apply_dense( - self, grad: tf.Tensor, var: tf.Variable, apply_state: Optional[str] = None + self, + grad: tf.Tensor, + var: tf.Variable, + apply_state: Optional[str] = None, ): """ Performs gradient update on sparse tensor. @@ -466,7 +487,8 @@ def _resource_apply_sparse( var : tensor The current variable. indices : tensor - Relevant indices to be used for masking the sparse tensor during update. + Relevant indices to be used for masking the sparse tensor during + update. apply_state : str, optional State of the optimizer. Defaults to None. @@ -532,7 +554,8 @@ def _resource_apply_sparse_duplicate_indices( var : tf.Variable Current variable. indices : tf.Tensor - Relevant indices to be used for masking the sparse tensor during update. + Relevant indices to be used for masking the sparse tensor during + update. apply_state : str, optional State of the optimizer. Defaults to None. @@ -579,7 +602,9 @@ def _apply(accum_gradient, var, apply_state): @tf.function(experimental_relax_shapes=True, reduce_retracing=True) def _reset_single_gradient(self, gradient: tf.Tensor): return gradient.assign( - tf.zeros_like(gradient), use_locking=self._use_locking, read_value=False + tf.zeros_like(gradient), + use_locking=self._use_locking, + read_value=False, ) def reset(self): @@ -593,7 +618,8 @@ def reset(self): @property def optimizer(self) -> tf.keras.optimizers.Optimizer: - """The optimizer that this AccumOptimizer is wrapping. In the case of mixed precision, this is the LossScaleOptimizer.""" + """The optimizer that this AccumOptimizer is wrapping. In the case of mixed + precision, this is the LossScaleOptimizer.""" return self._optimizer @property @@ -647,7 +673,9 @@ def get_config(self) -> dict: return config @classmethod - def from_config(cls, config: dict, custom_objects: Optional[str] = None) -> object: + def from_config( + cls, config: dict, custom_objects: Optional[str] = None + ) -> object: """Creates an instance of the optimizer from its config.""" optimizer_config = config.pop("optimizer") optimizer = tf.keras.optimizers.deserialize( diff --git a/gradient_accumulator/agc.py b/gradient_accumulator/agc.py index fc583c8..2399017 100644 --- a/gradient_accumulator/agc.py +++ b/gradient_accumulator/agc.py @@ -1,6 +1,5 @@ import tensorflow as tf - # implementation from: https://github.com/sayakpaul/Adaptive-Gradient-Clipping/blob/main/agc.py # noqa SCALAR = tf.constant([], dtype=tf.int32) LINEAR = tf.constant([0], dtype=tf.int32) @@ -51,7 +50,9 @@ def compute_reduction_axes(r): ) return axes - return compute_norm(x, axis=compute_reduction_axes(tf.rank(x)), keepdims=True) + return compute_norm( + x, axis=compute_reduction_axes(tf.rank(x)), keepdims=True + ) @tf.function @@ -59,7 +60,8 @@ def adaptive_clip_grad( parameters, gradients, clip_factor: float = 0.01, eps: float = 1e-3 ): """ - Performs adaptive gradient clipping on a given set of parameters and gradients. + Performs adaptive gradient clipping on a given set of parameters and + gradients. """ def clip_grad(param, grad): @@ -67,7 +69,9 @@ def clip_grad(param, grad): tf.math.maximum(unitwise_norm(param), eps), clip_factor ) grad_norm = unitwise_norm(grad) - adjusted_norm = tf.math.divide(max_norm, tf.math.maximum(grad_norm, 1e-6)) + adjusted_norm = tf.math.divide( + max_norm, tf.math.maximum(grad_norm, 1e-6) + ) new_grad = tf.where( tf.math.less(grad_norm, max_norm), grad, @@ -76,7 +80,9 @@ def clip_grad(param, grad): return new_grad new_grads = tf.map_fn( - lambda x: clip_grad(x[0], x[1]), (parameters, gradients), dtype=tf.float32 + lambda x: clip_grad(x[0], x[1]), + (parameters, gradients), + dtype=tf.float32, ) return new_grads diff --git a/shell/format.sh b/shell/format.sh old mode 100644 new mode 100755 diff --git a/shell/lint.sh b/shell/lint.sh old mode 100644 new mode 100755 diff --git a/tests/test_model_expected_result.py b/tests/test_model_expected_result.py index 01aaf8a..142702c 100644 --- a/tests/test_model_expected_result.py +++ b/tests/test_model_expected_result.py @@ -24,7 +24,7 @@ def test_model_expected_result(): # test with model wrapper instead result2 = run_experiment(bs=50, accum_steps=2, epochs=2, modeloropt="model") - + # results should be identical (theoretically, even in practice on CPU) if tf_version <= 6: # approximation poorer as enable_op_determinism() not available for tf < 2.7 From fde7729eb21d7e5dbe44b5fc8f0a7832cc52b685 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 4 Feb 2024 22:02:15 -0500 Subject: [PATCH 05/30] revert: relative imports --- shell/format.sh | 0 shell/lint.sh | 0 tests/test_adaptive_gradient_clipping.py | 6 ++++-- tests/test_batch_norm.py | 6 +++++- tests/test_bn_pretrained_swap.py | 10 +++++++++- tests/test_expected_result.py | 10 ++++++++-- tests/test_model_distribute.py | 2 +- tests/test_model_expected_result.py | 10 +++++++--- tests/test_multitask.py | 16 +++++++++++++--- tests/test_optimizer_distribute.py | 4 +++- tests/test_optimizer_invariance.py | 10 ++++++++-- tests/test_optimizer_wrapper.py | 7 ++++++- tests/test_param_count.py | 1 + tests/test_sparse_optimizer.py | 15 ++++++++++++--- tests/utils.py | 3 ++- 15 files changed, 79 insertions(+), 21 deletions(-) mode change 100755 => 100644 shell/format.sh mode change 100755 => 100644 shell/lint.sh diff --git a/shell/format.sh b/shell/format.sh old mode 100755 new mode 100644 diff --git a/shell/lint.sh b/shell/lint.sh old mode 100755 new mode 100644 diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index 81ef2a6..d68a8ea 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -2,11 +2,13 @@ import tensorflow as tf import tensorflow_datasets as tfds +from tensorflow.keras import mixed_precision from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel, unitwise_norm +from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import unitwise_norm -from tests.utils import normalize_img +from .utils import normalize_img def test_unitwise_norm(): diff --git a/tests/test_batch_norm.py b/tests/test_batch_norm.py index 35a343e..91dba0c 100644 --- a/tests/test_batch_norm.py +++ b/tests/test_batch_norm.py @@ -1,3 +1,6 @@ +import os +import random as python_random + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -6,7 +9,8 @@ from gradient_accumulator import GradientAccumulateModel from gradient_accumulator.layers import AccumBatchNormalization -from tests.utils import normalize_img, reset +from .utils import normalize_img +from .utils import reset def run_experiment( diff --git a/tests/test_bn_pretrained_swap.py b/tests/test_bn_pretrained_swap.py index 7ac99bb..c23fe36 100644 --- a/tests/test_bn_pretrained_swap.py +++ b/tests/test_bn_pretrained_swap.py @@ -1,11 +1,19 @@ +import os +import random as python_random + +import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator.layers import AccumBatchNormalization from gradient_accumulator.utils import replace_batchnorm_layers -from tests.utils import gray2rgb, normalize_img, resizeImage +from .utils import gray2rgb +from .utils import normalize_img +from .utils import reset +from .utils import resizeImage def test_swap_layer( diff --git a/tests/test_expected_result.py b/tests/test_expected_result.py index b808373..15a9404 100644 --- a/tests/test_expected_result.py +++ b/tests/test_expected_result.py @@ -1,11 +1,17 @@ +import os +import random as python_random + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import get_opt, normalize_img, reset +from .utils import get_opt +from .utils import normalize_img +from .utils import reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_model_distribute.py b/tests/test_model_distribute.py index 99bfa6e..b47d75d 100644 --- a/tests/test_model_distribute.py +++ b/tests/test_model_distribute.py @@ -4,7 +4,7 @@ from gradient_accumulator import GradientAccumulateModel -from tests.utils import get_opt +from .utils import get_opt def test_model_distribute(): diff --git a/tests/test_model_expected_result.py b/tests/test_model_expected_result.py index 142702c..b0e19cb 100644 --- a/tests/test_model_expected_result.py +++ b/tests/test_model_expected_result.py @@ -2,9 +2,13 @@ import tensorflow as tf from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import get_opt, normalize_img, reset, run_experiment +from .utils import get_opt +from .utils import normalize_img +from .utils import reset +from .utils import run_experiment # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) @@ -24,7 +28,7 @@ def test_model_expected_result(): # test with model wrapper instead result2 = run_experiment(bs=50, accum_steps=2, epochs=2, modeloropt="model") - + # results should be identical (theoretically, even in practice on CPU) if tf_version <= 6: # approximation poorer as enable_op_determinism() not available for tf < 2.7 diff --git a/tests/test_multitask.py b/tests/test_multitask.py index 2457148..311d3d1 100644 --- a/tests/test_multitask.py +++ b/tests/test_multitask.py @@ -1,13 +1,23 @@ +import os +import random as python_random + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D, UpSampling2D -from tensorflow.keras import Model +from tensorflow.keras.layers import Activation +from tensorflow.keras.layers import Conv2D +from tensorflow.keras.layers import Dense +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Input +from tensorflow.keras.layers import MaxPooling2D +from tensorflow.keras.layers import UpSampling2D +from tensorflow.keras.models import Model from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel -from tests.utils import normalize_img, reset +from .utils import normalize_img +from .utils import reset def create_multi_input_output(image, label): diff --git a/tests/test_optimizer_distribute.py b/tests/test_optimizer_distribute.py index 6fe3cf5..1dd5441 100644 --- a/tests/test_optimizer_distribute.py +++ b/tests/test_optimizer_distribute.py @@ -5,7 +5,9 @@ from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import get_opt, normalize_img, reset +from .utils import get_opt +from .utils import normalize_img +from .utils import reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_optimizer_invariance.py b/tests/test_optimizer_invariance.py index bc93a99..cc990a6 100644 --- a/tests/test_optimizer_invariance.py +++ b/tests/test_optimizer_invariance.py @@ -1,11 +1,17 @@ +import os +import random as python_random + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import get_opt, normalize_img, reset +from .utils import get_opt +from .utils import normalize_img +from .utils import reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_optimizer_wrapper.py b/tests/test_optimizer_wrapper.py index 68ef3ab..22a538c 100644 --- a/tests/test_optimizer_wrapper.py +++ b/tests/test_optimizer_wrapper.py @@ -1,3 +1,6 @@ +import os +import random as python_random + import numpy as np import tensorflow as tf import tensorflow_datasets as tfds @@ -5,7 +8,9 @@ from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import get_opt, normalize_img, reset +from .utils import get_opt +from .utils import normalize_img +from .utils import reset tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/test_param_count.py b/tests/test_param_count.py index 463ec90..02ffd5d 100644 --- a/tests/test_param_count.py +++ b/tests/test_param_count.py @@ -1,5 +1,6 @@ import tensorflow as tf from tensorflow.keras.layers import Dense +from tensorflow.keras.models import Sequential from gradient_accumulator import GradientAccumulateModel diff --git a/tests/test_sparse_optimizer.py b/tests/test_sparse_optimizer.py index 22f7bde..087abcb 100644 --- a/tests/test_sparse_optimizer.py +++ b/tests/test_sparse_optimizer.py @@ -1,11 +1,20 @@ +import os +import random as python_random + +import numpy as np import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras.layers import Dense, Embedding, Flatten -from tensorflow.keras.models import Sequential, load_model +from tensorflow.keras.layers import Dense +from tensorflow.keras.layers import Embedding +from tensorflow.keras.layers import Flatten +from tensorflow.keras.models import Sequential +from tensorflow.keras.models import load_model +from tensorflow.keras.preprocessing.sequence import pad_sequences +from tensorflow.keras.preprocessing.text import one_hot from gradient_accumulator import GradientAccumulateOptimizer -from tests.utils import reset +from .utils import reset # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) diff --git a/tests/utils.py b/tests/utils.py index 6cd5606..3e1ece4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,7 +6,8 @@ import tensorflow_datasets as tfds from tensorflow.keras.models import load_model -from gradient_accumulator import GradientAccumulateModel, GradientAccumulateOptimizer +from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer # get current tf minor version tf_version = int(tf.version.VERSION.split(".")[1]) From f2680f773acdd26d5f0788da6e828218c8612a58 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 4 Feb 2024 22:04:40 -0500 Subject: [PATCH 06/30] revert: setup.py --- setup.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 0c8a3f6..7d9a6d8 100644 --- a/setup.py +++ b/setup.py @@ -14,9 +14,7 @@ url="https://github.com/andreped/GradientAccumulator", packages=setuptools.find_packages(exclude=('tests', 'notebooks', 'assets', 'docs', 'shell')), install_requires=[ - "protobuf<=3.19.6", - 'tensorflow<=2.10.0,>=2.0.0; platform_system!="Darwin"', - 'tensorflow-macos<=2.10.0,>=2.0.0; platform_system=="Darwin"', + "tensorflow<=2.10.0", "numpy<=1.23.2", ], extras_require={"dev": [ @@ -26,7 +24,8 @@ "black==22.3.0", "isort==5.10.1", "flake8==4.0.1", - "tensorflow-datasets<=4.8.0", + "tensorflow-datasets<=4.8.2", + "protobuf<=3.20", ]}, classifiers=[ "Development Status :: 4 - Beta", From 18fcd314526fe96d1f5192c806a79df6dd358f60 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 4 Feb 2024 22:15:12 -0500 Subject: [PATCH 07/30] revert: __init__.py --- gradient_accumulator/__init__.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gradient_accumulator/__init__.py b/gradient_accumulator/__init__.py index a8d6100..5fa462b 100644 --- a/gradient_accumulator/__init__.py +++ b/gradient_accumulator/__init__.py @@ -1,6 +1,6 @@ -from gradient_accumulator.accumulators import GradientAccumulateModel -from gradient_accumulator.accumulators import GradientAccumulateOptimizer -from gradient_accumulator.agc import adaptive_clip_grad -from gradient_accumulator.agc import compute_norm -from gradient_accumulator.agc import unitwise_norm -from gradient_accumulator.layers import AccumBatchNormalization +from .accumulators import GradientAccumulateModel +from .accumulators import GradientAccumulateOptimizer +from .agc import adaptive_clip_grad +from .agc import compute_norm +from .agc import unitwise_norm +from .layers import AccumBatchNormalization From f239f3f795135571aaaafe0828fe30afcd85b151 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 4 Feb 2024 22:21:59 -0500 Subject: [PATCH 08/30] chore: linted --- shell/format.sh | 0 shell/lint.sh | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 shell/format.sh mode change 100644 => 100755 shell/lint.sh diff --git a/shell/format.sh b/shell/format.sh old mode 100644 new mode 100755 diff --git a/shell/lint.sh b/shell/lint.sh old mode 100644 new mode 100755 From a9c2be5c15e81ae0e80680020f13e28242236618 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Mon, 5 Feb 2024 12:26:57 -0500 Subject: [PATCH 09/30] revert: grad clipping, address in follow-up PR --- gradient_accumulator/accumulators.py | 8 +- gradient_accumulator/agc.py | 121 ++++++++++++------------- gradient_accumulator/utils.py | 2 +- setup.py | 2 +- tests/test_bn_pretrained_swap.py | 114 ++++++++++++------------ tests/test_model_distribute.py | 128 +++++++++++++-------------- 6 files changed, 184 insertions(+), 191 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 54aff86..3ab0a03 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -372,13 +372,13 @@ def apply_gradients( ): return self.step.assign_add(1, read_value=False) - @tf.function(experimental_relax_shapes=True) + @tf.function def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): return agc.adaptive_clip_grad( [var], [grad], clip_factor=self.clip_factor )[0] - @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + @tf.function def _parse_grad( self, accum_gradient: tf.Tensor, var: tf.Variable ) -> tf.Tensor: @@ -402,7 +402,7 @@ def return_grad(): tf.zeros_like(var, dtype=accum_gradient.dtype), ) - @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + @tf.function def reset_accum_gradient( self, accum_gradient: tf.Tensor, should_reset: tf.Tensor ): @@ -599,7 +599,7 @@ def _apply(accum_gradient, var, apply_state): return _apply(accum_gradient, var, apply_state) - @tf.function(experimental_relax_shapes=True, reduce_retracing=True) + @tf.function def _reset_single_gradient(self, gradient: tf.Tensor): return gradient.assign( tf.zeros_like(gradient), diff --git a/gradient_accumulator/agc.py b/gradient_accumulator/agc.py index 2399017..da5baa3 100644 --- a/gradient_accumulator/agc.py +++ b/gradient_accumulator/agc.py @@ -1,88 +1,81 @@ import tensorflow as tf -# implementation from: https://github.com/sayakpaul/Adaptive-Gradient-Clipping/blob/main/agc.py # noqa -SCALAR = tf.constant([], dtype=tf.int32) -LINEAR = tf.constant([0], dtype=tf.int32) -TENSOR2D = tf.constant([0, 1], dtype=tf.int32) -TENSOR3D = tf.constant([0, 1, 2], dtype=tf.int32) -TENSOR4D = tf.constant([0, 1, 2, 3], dtype=tf.int32) - -@tf.function +# implementation from: https://github.com/sayakpaul/Adaptive-Gradient-Clipping/blob/main/agc.py # noqa def compute_norm(x, axis, keepdims): """ Computes the euclidean norm of a tensor :math:`x`. + + Args: + x: input tensor. + axis: which axis to compute norm across. + keepdims: whether to keep dimension after applying along axis. + + Returns: + Euclidean norm. """ - return tf.sqrt(tf.reduce_sum(tf.square(x), axis=axis, keepdims=keepdims)) + return tf.math.reduce_sum(x**2, axis=axis, keepdims=keepdims) ** 0.5 -@tf.function def unitwise_norm(x): """ - Computes the unitwise norm of a tensor. - """ + Wrapper class which dynamically sets `axis` and `keepdims` given an + input `x` for calculating euclidean norm. - def compute_reduction_axes(r): - axes = tf.case( - [ - ( - tf.equal(r, 1), - lambda: SCALAR, - ), - ( - tf.equal(r, 2), - lambda: LINEAR, - ), - ( - tf.equal(r, 3), - lambda: TENSOR2D, - ), - ( - tf.equal(r, 4), - lambda: TENSOR3D, - ), - ( - tf.equal(r, 5), - lambda: TENSOR4D, - ), - ], - default=lambda: SCALAR, - ) - return axes + Args: + x: input tensor. - return compute_norm( - x, axis=compute_reduction_axes(tf.rank(x)), keepdims=True - ) + Returns: + Euclidean norm. + """ + if len(x.get_shape()) <= 1: # Scalars and vectors + axis = None + keepdims = False + elif len(x.get_shape()) in [ + 2, + 3, + ]: # Linear layers of shape IO or multihead linear + axis = 0 + keepdims = True + elif len(x.get_shape()) == 4: # Conv kernels of shape HWIO + axis = [0, 1, 2] + keepdims = True + elif len(x.get_shape()) == 5: # Conv kernels of shape HWDIO + axis = [0, 1, 2, 3] + keepdims = True + else: + raise ValueError(f"Got a parameter with shape not in [1, 2, 4, 5]! {x}") + return compute_norm(x, axis, keepdims) -@tf.function def adaptive_clip_grad( parameters, gradients, clip_factor: float = 0.01, eps: float = 1e-3 ): """ Performs adaptive gradient clipping on a given set of parameters and gradients. - """ - def clip_grad(param, grad): - max_norm = tf.math.multiply( - tf.math.maximum(unitwise_norm(param), eps), clip_factor - ) - grad_norm = unitwise_norm(grad) - adjusted_norm = tf.math.divide( - max_norm, tf.math.maximum(grad_norm, 1e-6) - ) - new_grad = tf.where( - tf.math.less(grad_norm, max_norm), - grad, - tf.math.multiply(grad, adjusted_norm), - ) - return new_grad + * Official JAX implementation (paper authors): + https://github.com/deepmind/deepmind-research/tree/master/nfnets # noqa + * Ross Wightman's implementation + https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/agc.py # noqa - new_grads = tf.map_fn( - lambda x: clip_grad(x[0], x[1]), - (parameters, gradients), - dtype=tf.float32, - ) + Args: + parameters: Which parameters to apply method on. + gradients: Which gradients to apply clipping on. + clip_factor: Sets upper limit for gradient clipping. + eps: Epsilon - small number in :math:`max()` to avoid zero norm and + preserve numerical stability. - return new_grads + Returns: + Updated gradients after gradient clipping. + """ + new_grads = [] + for (params, grads) in zip(parameters, gradients): + p_norm = unitwise_norm(params) + max_norm = tf.math.maximum(p_norm, eps) * clip_factor + grad_norm = unitwise_norm(grads) + clipped_grad = grads * (max_norm / tf.math.maximum(grad_norm, 1e-6)) + new_grad = tf.where(grad_norm < max_norm, grads, clipped_grad) + new_grads.append(new_grad) + return new_grads \ No newline at end of file diff --git a/gradient_accumulator/utils.py b/gradient_accumulator/utils.py index fd845b3..14785eb 100644 --- a/gradient_accumulator/utils.py +++ b/gradient_accumulator/utils.py @@ -1,6 +1,6 @@ import tensorflow as tf -from gradient_accumulator.layers import AccumBatchNormalization +from .layers import AccumBatchNormalization def replace_batchnorm_layers(model, accum_steps, position="replace"): diff --git a/setup.py b/setup.py index 7d9a6d8..ba5b29c 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ url="https://github.com/andreped/GradientAccumulator", packages=setuptools.find_packages(exclude=('tests', 'notebooks', 'assets', 'docs', 'shell')), install_requires=[ - "tensorflow<=2.10.0", + "tensorflow", "numpy<=1.23.2", ], extras_require={"dev": [ diff --git a/tests/test_bn_pretrained_swap.py b/tests/test_bn_pretrained_swap.py index c23fe36..187cf5e 100644 --- a/tests/test_bn_pretrained_swap.py +++ b/tests/test_bn_pretrained_swap.py @@ -16,70 +16,70 @@ from .utils import resizeImage -def test_swap_layer( - custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1 -): - # load dataset - (ds_train, ds_test), ds_info = tfds.load( - "mnist", - split=["train", "test"], - shuffle_files=True, - as_supervised=True, - with_info=True, - ) +# def test_swap_layer( +# custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1 +# ): +# # load dataset +# (ds_train, ds_test), ds_info = tfds.load( +# "mnist", +# split=["train", "test"], +# shuffle_files=True, +# as_supervised=True, +# with_info=True, +# ) - # build train pipeline - ds_train = ds_train.map(normalize_img) - ds_train = ds_train.map(gray2rgb) - ds_train = ds_train.map(resizeImage) - ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) - ds_train = ds_train.batch(bs) - ds_train = ds_train.prefetch(1) +# # build train pipeline +# ds_train = ds_train.map(normalize_img) +# ds_train = ds_train.map(gray2rgb) +# ds_train = ds_train.map(resizeImage) +# ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) +# ds_train = ds_train.batch(bs) +# ds_train = ds_train.prefetch(1) - # build test pipeline - ds_test = ds_test.map(normalize_img) - ds_test = ds_test.map(gray2rgb) - ds_test = ds_test.map(resizeImage) - ds_test = ds_test.batch(bs) - ds_test = ds_test.prefetch(1) +# # build test pipeline +# ds_test = ds_test.map(normalize_img) +# ds_test = ds_test.map(gray2rgb) +# ds_test = ds_test.map(resizeImage) +# ds_test = ds_test.batch(bs) +# ds_test = ds_test.prefetch(1) - # create model - base_model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), weights="imagenet", include_top=False) - base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps) +# # create model +# base_model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), weights="imagenet", include_top=False) +# base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps) - input_ = tf.keras.layers.Input(shape=(32, 32, 3)) - x = base_model(input_) - x = tf.keras.layers.Dense(10, activation="softmax")(x) - model = tf.keras.Model(inputs=input_, outputs=x) +# input_ = tf.keras.layers.Input(shape=(32, 32, 3)) +# x = base_model(input_) +# x = tf.keras.layers.Dense(10, activation="softmax")(x) +# model = tf.keras.Model(inputs=input_, outputs=x) - # wrap model to use gradient accumulation - if accum_steps > 1: - model = GradientAccumulateModel( - accum_steps=accum_steps, inputs=model.input, outputs=model.output - ) +# # wrap model to use gradient accumulation +# if accum_steps > 1: +# model = GradientAccumulateModel( +# accum_steps=accum_steps, inputs=model.input, outputs=model.output +# ) - # compile model - model.compile( - optimizer=tf.keras.optimizers.SGD(1e-2), - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], - ) +# # compile model +# model.compile( +# optimizer=tf.keras.optimizers.SGD(1e-2), +# loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), +# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], +# ) - # train model - model.fit( - ds_train, - epochs=epochs, - validation_data=ds_test, - steps_per_epoch=4, - validation_steps=4, - ) +# # train model +# model.fit( +# ds_train, +# epochs=epochs, +# validation_data=ds_test, +# steps_per_epoch=4, +# validation_steps=4, +# ) - model.save("./trained_model") +# model.save("./trained_model") - # load trained model and test - del model - trained_model = load_model("./trained_model", compile=True) +# # load trained model and test +# del model +# trained_model = load_model("./trained_model", compile=True) - result = trained_model.evaluate(ds_test, verbose=1) - print(result) - return result +# result = trained_model.evaluate(ds_test, verbose=1) +# print(result) +# return result diff --git a/tests/test_model_distribute.py b/tests/test_model_distribute.py index b47d75d..f55c0db 100644 --- a/tests/test_model_distribute.py +++ b/tests/test_model_distribute.py @@ -7,67 +7,67 @@ from .utils import get_opt -def test_model_distribute(): - strategy = tf.distribute.MirroredStrategy() - - # load dataset - (ds_train, ds_test), ds_info = tfds.load( - "mnist", - split=["train", "test"], - shuffle_files=True, - as_supervised=True, - with_info=True, - ) - - # build train pipeline - ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) - ds_train = ds_train.batch(100) - ds_train = ds_train.prefetch(1) - - # build test pipeline - ds_test = ds_test.batch(100) - ds_test = ds_test.prefetch(1) - - with strategy.scope(): - # create model - model = tf.keras.models.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(16, activation="relu"), - tf.keras.layers.Dense(10), - ] - ) - model = GradientAccumulateModel( - accum_steps=4, - inputs=model.input, - outputs=model.output, - experimental_distributed_support=True, - ) - - # define optimizer - currently only SGD compatible with GAOptimizerWrapper - opt = get_opt("SGD") - - # compile model - model.compile( - optimizer=opt, - loss=tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True - ), - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], - ) - - # train model - model.fit(ds_train, epochs=3, validation_data=ds_test, verbose=1) - - model.save("./trained_model") - - # load trained model and test - del model - trained_model = load_model("./trained_model", compile=True) - - result = trained_model.evaluate(ds_test, verbose=1) - print(result) - - -if __name__ == "__main__": - test_model_distribute() +# def test_model_distribute(): +# strategy = tf.distribute.MirroredStrategy() + +# # load dataset +# (ds_train, ds_test), ds_info = tfds.load( +# "mnist", +# split=["train", "test"], +# shuffle_files=True, +# as_supervised=True, +# with_info=True, +# ) + +# # build train pipeline +# ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) +# ds_train = ds_train.batch(100) +# ds_train = ds_train.prefetch(1) + +# # build test pipeline +# ds_test = ds_test.batch(100) +# ds_test = ds_test.prefetch(1) + +# with strategy.scope(): +# # create model +# model = tf.keras.models.Sequential( +# [ +# tf.keras.layers.Flatten(input_shape=(28, 28)), +# tf.keras.layers.Dense(16, activation="relu"), +# tf.keras.layers.Dense(10), +# ] +# ) +# model = GradientAccumulateModel( +# accum_steps=4, +# inputs=model.input, +# outputs=model.output, +# experimental_distributed_support=True, +# ) + +# # define optimizer - currently only SGD compatible with GAOptimizerWrapper +# opt = get_opt("SGD") + +# # compile model +# model.compile( +# optimizer=opt, +# loss=tf.keras.losses.SparseCategoricalCrossentropy( +# from_logits=True +# ), +# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], +# ) + +# # train model +# model.fit(ds_train, epochs=3, validation_data=ds_test, verbose=1) + +# model.save("./trained_model") + +# # load trained model and test +# del model +# trained_model = load_model("./trained_model", compile=True) + +# result = trained_model.evaluate(ds_test, verbose=1) +# print(result) + + +# if __name__ == "__main__": +# test_model_distribute() From 96d6675cc1fe0a98ec1ac387432d4af8f824afc6 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Mon, 5 Feb 2024 12:28:02 -0500 Subject: [PATCH 10/30] fix: uncomment broken unit test --- tests/test_bn_pretrained_swap.py | 114 +++++++++++++-------------- tests/test_model_distribute.py | 128 +++++++++++++++---------------- 2 files changed, 121 insertions(+), 121 deletions(-) diff --git a/tests/test_bn_pretrained_swap.py b/tests/test_bn_pretrained_swap.py index 187cf5e..c23fe36 100644 --- a/tests/test_bn_pretrained_swap.py +++ b/tests/test_bn_pretrained_swap.py @@ -16,70 +16,70 @@ from .utils import resizeImage -# def test_swap_layer( -# custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1 -# ): -# # load dataset -# (ds_train, ds_test), ds_info = tfds.load( -# "mnist", -# split=["train", "test"], -# shuffle_files=True, -# as_supervised=True, -# with_info=True, -# ) +def test_swap_layer( + custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1 +): + # load dataset + (ds_train, ds_test), ds_info = tfds.load( + "mnist", + split=["train", "test"], + shuffle_files=True, + as_supervised=True, + with_info=True, + ) -# # build train pipeline -# ds_train = ds_train.map(normalize_img) -# ds_train = ds_train.map(gray2rgb) -# ds_train = ds_train.map(resizeImage) -# ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) -# ds_train = ds_train.batch(bs) -# ds_train = ds_train.prefetch(1) + # build train pipeline + ds_train = ds_train.map(normalize_img) + ds_train = ds_train.map(gray2rgb) + ds_train = ds_train.map(resizeImage) + ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) + ds_train = ds_train.batch(bs) + ds_train = ds_train.prefetch(1) -# # build test pipeline -# ds_test = ds_test.map(normalize_img) -# ds_test = ds_test.map(gray2rgb) -# ds_test = ds_test.map(resizeImage) -# ds_test = ds_test.batch(bs) -# ds_test = ds_test.prefetch(1) + # build test pipeline + ds_test = ds_test.map(normalize_img) + ds_test = ds_test.map(gray2rgb) + ds_test = ds_test.map(resizeImage) + ds_test = ds_test.batch(bs) + ds_test = ds_test.prefetch(1) -# # create model -# base_model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), weights="imagenet", include_top=False) -# base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps) + # create model + base_model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), weights="imagenet", include_top=False) + base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps) -# input_ = tf.keras.layers.Input(shape=(32, 32, 3)) -# x = base_model(input_) -# x = tf.keras.layers.Dense(10, activation="softmax")(x) -# model = tf.keras.Model(inputs=input_, outputs=x) + input_ = tf.keras.layers.Input(shape=(32, 32, 3)) + x = base_model(input_) + x = tf.keras.layers.Dense(10, activation="softmax")(x) + model = tf.keras.Model(inputs=input_, outputs=x) -# # wrap model to use gradient accumulation -# if accum_steps > 1: -# model = GradientAccumulateModel( -# accum_steps=accum_steps, inputs=model.input, outputs=model.output -# ) + # wrap model to use gradient accumulation + if accum_steps > 1: + model = GradientAccumulateModel( + accum_steps=accum_steps, inputs=model.input, outputs=model.output + ) -# # compile model -# model.compile( -# optimizer=tf.keras.optimizers.SGD(1e-2), -# loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), -# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], -# ) + # compile model + model.compile( + optimizer=tf.keras.optimizers.SGD(1e-2), + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) -# # train model -# model.fit( -# ds_train, -# epochs=epochs, -# validation_data=ds_test, -# steps_per_epoch=4, -# validation_steps=4, -# ) + # train model + model.fit( + ds_train, + epochs=epochs, + validation_data=ds_test, + steps_per_epoch=4, + validation_steps=4, + ) -# model.save("./trained_model") + model.save("./trained_model") -# # load trained model and test -# del model -# trained_model = load_model("./trained_model", compile=True) + # load trained model and test + del model + trained_model = load_model("./trained_model", compile=True) -# result = trained_model.evaluate(ds_test, verbose=1) -# print(result) -# return result + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + return result diff --git a/tests/test_model_distribute.py b/tests/test_model_distribute.py index f55c0db..b47d75d 100644 --- a/tests/test_model_distribute.py +++ b/tests/test_model_distribute.py @@ -7,67 +7,67 @@ from .utils import get_opt -# def test_model_distribute(): -# strategy = tf.distribute.MirroredStrategy() - -# # load dataset -# (ds_train, ds_test), ds_info = tfds.load( -# "mnist", -# split=["train", "test"], -# shuffle_files=True, -# as_supervised=True, -# with_info=True, -# ) - -# # build train pipeline -# ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) -# ds_train = ds_train.batch(100) -# ds_train = ds_train.prefetch(1) - -# # build test pipeline -# ds_test = ds_test.batch(100) -# ds_test = ds_test.prefetch(1) - -# with strategy.scope(): -# # create model -# model = tf.keras.models.Sequential( -# [ -# tf.keras.layers.Flatten(input_shape=(28, 28)), -# tf.keras.layers.Dense(16, activation="relu"), -# tf.keras.layers.Dense(10), -# ] -# ) -# model = GradientAccumulateModel( -# accum_steps=4, -# inputs=model.input, -# outputs=model.output, -# experimental_distributed_support=True, -# ) - -# # define optimizer - currently only SGD compatible with GAOptimizerWrapper -# opt = get_opt("SGD") - -# # compile model -# model.compile( -# optimizer=opt, -# loss=tf.keras.losses.SparseCategoricalCrossentropy( -# from_logits=True -# ), -# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], -# ) - -# # train model -# model.fit(ds_train, epochs=3, validation_data=ds_test, verbose=1) - -# model.save("./trained_model") - -# # load trained model and test -# del model -# trained_model = load_model("./trained_model", compile=True) - -# result = trained_model.evaluate(ds_test, verbose=1) -# print(result) - - -# if __name__ == "__main__": -# test_model_distribute() +def test_model_distribute(): + strategy = tf.distribute.MirroredStrategy() + + # load dataset + (ds_train, ds_test), ds_info = tfds.load( + "mnist", + split=["train", "test"], + shuffle_files=True, + as_supervised=True, + with_info=True, + ) + + # build train pipeline + ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) + ds_train = ds_train.batch(100) + ds_train = ds_train.prefetch(1) + + # build test pipeline + ds_test = ds_test.batch(100) + ds_test = ds_test.prefetch(1) + + with strategy.scope(): + # create model + model = tf.keras.models.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(16, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + model = GradientAccumulateModel( + accum_steps=4, + inputs=model.input, + outputs=model.output, + experimental_distributed_support=True, + ) + + # define optimizer - currently only SGD compatible with GAOptimizerWrapper + opt = get_opt("SGD") + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy( + from_logits=True + ), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit(ds_train, epochs=3, validation_data=ds_test, verbose=1) + + model.save("./trained_model") + + # load trained model and test + del model + trained_model = load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + +if __name__ == "__main__": + test_model_distribute() From 609b7bbeaadd8c3231232e1582e025622c7c1269 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Mon, 5 Feb 2024 12:37:48 -0500 Subject: [PATCH 11/30] chore: move to utils --- gradient_accumulator/accumulators.py | 11 ++--------- gradient_accumulator/agc.py | 2 +- gradient_accumulator/utils.py | 7 +++++++ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 3ab0a03..4aa7fd7 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -2,7 +2,8 @@ import tensorflow as tf -from gradient_accumulator import agc +from . import agc +from .utils import get_gradients # dynamically handle which Optimizer class to use dep on tf version opt = tf.keras.optimizers.Optimizer @@ -194,14 +195,6 @@ def reinit_grad_accum(self): ] -def get_gradients(gradients: list): - return [ - gradient.read_value() - for gradient in gradients - if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) - ] - - # Implementation was derived from: # https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa # https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa diff --git a/gradient_accumulator/agc.py b/gradient_accumulator/agc.py index da5baa3..88a3d4f 100644 --- a/gradient_accumulator/agc.py +++ b/gradient_accumulator/agc.py @@ -78,4 +78,4 @@ def adaptive_clip_grad( clipped_grad = grads * (max_norm / tf.math.maximum(grad_norm, 1e-6)) new_grad = tf.where(grad_norm < max_norm, grads, clipped_grad) new_grads.append(new_grad) - return new_grads \ No newline at end of file + return new_grads diff --git a/gradient_accumulator/utils.py b/gradient_accumulator/utils.py index 14785eb..eb8f77f 100644 --- a/gradient_accumulator/utils.py +++ b/gradient_accumulator/utils.py @@ -73,3 +73,10 @@ def replace_batchnorm_layers(model, accum_steps, position="replace"): model_outputs.append(x) return tf.keras.Model(inputs=model.inputs, outputs=x) + +def get_gradients(gradients: list): + return [ + gradient.read_value() + for gradient in gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) + ] From 858a574fa5bc52c787f801e5e95f25b2a9834351 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Mon, 5 Feb 2024 23:04:20 -0500 Subject: [PATCH 12/30] chore: re-lint --- gradient_accumulator/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gradient_accumulator/utils.py b/gradient_accumulator/utils.py index eb8f77f..4e0bb5b 100644 --- a/gradient_accumulator/utils.py +++ b/gradient_accumulator/utils.py @@ -74,6 +74,7 @@ def replace_batchnorm_layers(model, accum_steps, position="replace"): return tf.keras.Model(inputs=model.inputs, outputs=x) + def get_gradients(gradients: list): return [ gradient.read_value() From 311cc358c484293d1a099f071ec13e55cb3eb75c Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Tue, 6 Feb 2024 00:37:43 -0500 Subject: [PATCH 13/30] fix: failing unit test --- gradient_accumulator/accumulators.py | 60 ++++++++++++++-------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 4aa7fd7..4c08b56 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -395,15 +395,18 @@ def return_grad(): tf.zeros_like(var, dtype=accum_gradient.dtype), ) - @tf.function - def reset_accum_gradient( - self, accum_gradient: tf.Tensor, should_reset: tf.Tensor - ): - return tf.where( - should_reset, - accum_gradient.assign(tf.zeros_like(accum_gradient)), + def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): + reset_val = tf.where( + grad == accum_gradient, + tf.zeros_like(accum_gradient), accum_gradient, ) + reset_op = accum_gradient.assign( + reset_val, + use_locking=self._use_locking, + read_value=False, + ) + return reset_op def _resource_apply_dense( self, @@ -448,17 +451,15 @@ def _resource_apply_dense( ) def _apply(accum_gradient, var, apply_state): + grad = self._parse_grad(accum_gradient, var) + train_op = self.base_optimizer._resource_apply_dense( - self._parse_grad(accum_gradient, var), + grad, var, - apply_state=apply_state, - ) - - should_reset = tf.equal( - tf.math.mod(self.step, self._accum_steps), self._zero + apply_state=apply_state if apply_state else None, ) - reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + reset_op = self.reset_accum_gradient(accum_gradient, grad) return tf.group(train_op, reset_op) @@ -513,18 +514,16 @@ def _resource_apply_sparse( ) def _apply(accum_gradient, var, apply_state): - train_op = self._optimizer._resource_apply_sparse( + grad = self._parse_grad(accum_gradient, var) + + train_op = self.base_optimizer._resource_apply_sparse( accum_gradient.sparse_read(indices), var, indices, - apply_state=apply_state, - ) - - should_reset = tf.equal( - tf.math.mod(self.step, self._accum_steps), self._zero + apply_state=apply_state if apply_state else None, ) - reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + reset_op = self.reset_accum_gradient(accum_gradient, grad) return tf.group(train_op, reset_op) @@ -574,19 +573,18 @@ def _resource_apply_sparse_duplicate_indices( ) def _apply(accum_gradient, var, apply_state): - train_op = self._optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state, - ) + grad = self._parse_grad(accum_gradient, var) - # train operation must be executed before we can reset gradients - should_reset = tf.equal( - tf.math.mod(self.step, self._accum_steps), self._zero + train_op = ( + self.base_optimizer._resource_apply_sparse_duplicate_indices( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state if apply_state else None, + ) ) - reset_op = self.reset_accum_gradient(accum_gradient, should_reset) + reset_op = self.reset_accum_gradient(accum_gradient, grad) return tf.group(train_op, reset_op) From 8a0183f91d6a210a1fa8bf48b10a5c30e05a1870 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Tue, 6 Feb 2024 01:17:35 -0500 Subject: [PATCH 14/30] test: update mixed_precision unit test to also test optimizer --- gradient_accumulator/accumulators.py | 29 +- tests/test_mixed_precision.py | 795 +++++++++++++++++++++++---- 2 files changed, 700 insertions(+), 124 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 4c08b56..11cbcf7 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -395,18 +395,13 @@ def return_grad(): tf.zeros_like(var, dtype=accum_gradient.dtype), ) + @tf.function def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): - reset_val = tf.where( - grad == accum_gradient, + return tf.where( + tf.math.equal(grad, accum_gradient), tf.zeros_like(accum_gradient), accum_gradient, ) - reset_op = accum_gradient.assign( - reset_val, - use_locking=self._use_locking, - read_value=False, - ) - return reset_op def _resource_apply_dense( self, @@ -459,7 +454,11 @@ def _apply(accum_gradient, var, apply_state): apply_state=apply_state if apply_state else None, ) - reset_op = self.reset_accum_gradient(accum_gradient, grad) + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) return tf.group(train_op, reset_op) @@ -523,7 +522,11 @@ def _apply(accum_gradient, var, apply_state): apply_state=apply_state if apply_state else None, ) - reset_op = self.reset_accum_gradient(accum_gradient, grad) + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) return tf.group(train_op, reset_op) @@ -584,7 +587,11 @@ def _apply(accum_gradient, var, apply_state): ) ) - reset_op = self.reset_accum_gradient(accum_gradient, grad) + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) return tf.group(train_op, reset_op) diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py index d24d50a..f36f586 100644 --- a/tests/test_mixed_precision.py +++ b/tests/test_mixed_precision.py @@ -1,114 +1,683 @@ -import multiprocessing as mp - - -def run_experiment(): - import os - - import tensorflow as tf - import tensorflow_datasets as tfds - from tensorflow.keras import mixed_precision - - from gradient_accumulator import GradientAccumulateModel - - from .utils import normalize_img - - # disable GPU - os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - - # set mixed global precision policy - mixed_precision.set_global_policy("mixed_float16") - - # load dataset - (ds_train, ds_test), ds_info = tfds.load( - "mnist", - split=["train", "test"], - shuffle_files=True, - as_supervised=True, - with_info=True, - ) - - # build train pipeline - ds_train = ds_train.map(normalize_img) - ds_train = ds_train.cache() - ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) - ds_train = ds_train.batch( - 32 - ) # multiplum of 8 on GPU to maximize performance - ds_train = ds_train.prefetch(1) - - # build test pipeline - ds_test = ds_test.map(normalize_img) - ds_test = ds_test.batch(32) - ds_test = ds_test.cache() - ds_test = ds_test.prefetch(1) - - # create model - model = tf.keras.models.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(32, activation="relu"), # 32 multiplum of 8 - tf.keras.layers.Dense( - 10, dtype="float32" - ), # output not numerically stable with float16 +from typing import Optional + +import tensorflow as tf + +from . import agc +from .utils import get_gradients + +# dynamically handle which Optimizer class to use dep on tf version +opt = tf.keras.optimizers.Optimizer +if int(tf.version.VERSION.split(".")[1]) > 10: + opt = tf.keras.optimizers.legacy.Optimizer + + +# https://stackoverflow.com/a/66524901 +# https://keras.io/guides/customizing_what_happens_in_fit/ # noqa +@tf.keras.utils.register_keras_serializable("gradient-accumulator") +class GradientAccumulateModel(tf.keras.Model): + """Model wrapper for gradient accumulation.""" + + def __init__( + self, + accum_steps: int = 1, + mixed_precision: bool = False, + use_agc: bool = False, + clip_factor: float = 0.01, + eps: float = 1e-3, + experimental_distributed_support: bool = False, + *args, + **kwargs, + ): + """Adds gradient accumulation support to existing Keras Model. + + Args: + accum_steps: int > 0. Update gradient in every accumulation steps. + mixed_precision: bool. Whether to enable mixed precision. + use_agc: bool. Whether to enable adaptive gradient clipping. + clip_factor: float > 0. Upper limit to gradient clipping. + eps: float > 0. Small value to aid numerical stability. + experimental_distributed_support: bool. Whether to enable + experimental multi-gpu support. Only compatible with SGD. Can + be used with other optimizers but we do not have complete + control of the optimizer's state between accum_steps. + **kwargs: keyword arguments. + """ + super().__init__(*args, **kwargs) + self.accum_steps = tf.constant( + accum_steps, dtype=tf.int32, name="accum_steps" + ) + self.accum_step_counter = tf.Variable( + 0, + dtype=tf.int32, + trainable=False, + name="accum_counter", + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + self.first_call = True + self.mixed_precision = mixed_precision + self.use_agc = use_agc + self.clip_factor = clip_factor + self.eps = eps + self.experimental_distributed_support = experimental_distributed_support + self.dtype_value = self.dtype + self.gradient_accumulation = None + self.reinit_grad_accum() + + def train_step(self, data): + """Performs single train step.""" + # need to reinit accumulator for models subclassed from tf.keras.Model + if self.first_call: + self.reinit_grad_accum() + self.first_call = False + + self.accum_step_counter.assign_add(1) + + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + # NOTE that x and y are lists of inputs and outputs, + # hence this wrapper supports multi-input-output models + if len(data) == 3: + x, y, sample_weight = data + else: + sample_weight = None + x, y = data + + # Gradient Tape + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # forward pass + + # Compute the loss value. + # The loss function is configured in `compile()`. + loss = self.compiled_loss( + y, + y_pred, + sample_weight=sample_weight, + regularization_losses=self.losses, + ) + loss = loss / tf.cast( + self.accum_steps, loss.dtype + ) # MEAN reduction here IMPORTANT! Don't use SUM! + + # scale loss if mixed precision is enabled + if self.mixed_precision: + loss = self.optimizer.get_scaled_loss(loss) + + # Calculate batch gradients -> these are scaled gradients if mixed + # precision is enabled + gradients = tape.gradient( + loss, + self.trainable_variables, + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) + + # scale gradients if mixed precision is enabled + if self.mixed_precision: + gradients = self.optimizer.get_unscaled_gradients(gradients) + + # apply adaptive gradient clipping -> should be AFTER unscaling + if self.use_agc: + gradients = agc.adaptive_clip_grad( + self.trainable_variables, + gradients, + clip_factor=self.clip_factor, + eps=self.eps, + ) + + # Accumulate batch gradients + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add( + gradients[i], read_value=False + ) + + # accumulate gradients only after certain number of steps + # self.accum_steps.assign(self.accum_steps * tf.cast(tf.logical_not(\ + # tf.equal(self.accum_step_counter,self.accum_steps)), tf.int32)) + if not self.experimental_distributed_support: + tf.cond( + tf.equal(self.accum_step_counter, self.accum_steps), + true_fn=self.apply_accu_gradients, + false_fn=lambda: None, + ) + + else: + # NOTE: This enabled multi-gpu support, but only for SGD (!) + should_apply = tf.equal(self.accum_step_counter, self.accum_steps) + logical_grads = [ + tf.cast(should_apply, grad_component.dtype) * grad_component + for grad_component in self.gradient_accumulation + ] + self.optimizer.apply_gradients( + zip(logical_grads, self.trainable_variables) + ) + self.accum_step_counter.assign( + self.accum_step_counter + * tf.cast(tf.logical_not(should_apply), tf.int32) + ) + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign_add(-1 * logical_grads[i]) + + # update metrics + self.compiled_metrics.update_state( + y, y_pred, sample_weight=sample_weight + ) + return {m.name: m.result() for m in self.metrics} + + def apply_accu_gradients(self): + """Performs gradient update and resets slots afterwards.""" + # apply accumulated gradients + self.optimizer.apply_gradients( + zip(self.gradient_accumulation, self.trainable_variables) + ) + + # reset + self.accum_step_counter.assign(0) + for i in range(len(self.gradient_accumulation)): + self.gradient_accumulation[i].assign( + tf.zeros_like( + self.trainable_variables[i], dtype=self.dtype_value + ), + read_value=False, + ) + + def reinit_grad_accum(self): + """Reinitialized gradient accumulator slots.""" + # reinitialize gradient accumulator + self.gradient_accumulation = [ + tf.Variable( + tf.zeros_like(v, dtype=self.dtype_value), + trainable=False, + name="accum_" + str(i), + synchronization=tf.VariableSynchronization.ON_READ, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + for i, v in enumerate(self.trainable_variables) + ] + + +# Implementation was derived from: +# https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa +# https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa +@tf.keras.utils.register_keras_serializable("gradient-accumulator") +class GradientAccumulateOptimizer(opt): + """Optimizer wrapper for gradient accumulation.""" + + def __init__( + self, + optimizer: str = "SGD", + accum_steps: int = 1, + reduction: str = "MEAN", + agc: bool = False, + mixed_precision: bool = False, + name: str = "GradientAccumulateOptimizer", + dtype: tf.dtypes.DType = tf.float32, + **kwargs, + ): + """ + Construct a new GradientAccumulateOptimizer optimizer. + + Parameters + ---------- + optimizer : str or tf.keras.optimizers.Optimizer + Optimizer that will be used to compute and apply gradients. + accum_steps : int, optional + Update gradient in every accumulation steps, must be > 0. + reduction : str, optional + Gradient reduction method to use. Can be 'MEAN' or 'SUM'. + agc : bool, optional + Whether to use adaptive gradient clipping. Defaults to False. + mixed_precision : bool, optional + Whether to use mixed precision. Defaults to False. + name : str, optional + Name for the operations created when applying gradients. Defaults to + "GradientAccumulateOptimizer". + **kwargs : dict + Additional keyword arguments. Allowed keys are: + - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. + - `lr`: Learning rate, included for backward compatibility. Use + `learning_rate` instead. + + Notes + ----- + Adding support for sparse tensors was tricky. For correct implementation, both + `_resource_apply_sparse()` + and `_resource_apply_sparse_duplicate_indices()` methods need to be implemented. + + References + ---------- + .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa + + """ + super().__init__(name, **kwargs) + optimizer = tf.keras.optimizers.get(optimizer) + self._optimizer = ( + tf.keras.mixed_precision.LossScaleOptimizer( + optimizer + ) + if mixed_precision and not isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer) + else optimizer + ) + self.base_optimizer = ( + self._optimizer.inner_optimizer + if mixed_precision + else self._optimizer + ) + self.mixed_precision = mixed_precision + self._mixed_precision = tf.constant(mixed_precision, dtype=tf.bool) + self.accum_steps = accum_steps + self._accum_steps = tf.constant(accum_steps, dtype=tf.int64) + self.reduction = reduction + self._reduction = tf.constant(reduction, dtype=tf.string) + self._step = tf.Variable( + initial_value=1, + trainable=False, + dtype=tf.int64, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + self._weights.append(self._step) + self._zero = tf.constant(0, dtype=tf.int64) + self.dtype = dtype + self.agc = agc + self._agc = tf.constant(agc) + if agc: + if "clip_factor" in kwargs: + self.clip_factor = tf.constant( + kwargs.pop("clip_factor"), dtype=tf.float32 + ) + else: + self.clip_factor = tf.constant(0.01, dtype=tf.float32) + else: + self.clip_factor = tf.constant(0.0, dtype=tf.float32) + + def get_slot(self, *args, **kwargs): + return self._optimizer.get_slot(*args, **kwargs) + + def add_slot(self, *args, **kwargs): + return self._optimizer.add_slot(*args, **kwargs) + + def _create_slots(self, var_list: list): + # create slots using the base optimizer + self.base_optimizer._create_slots(var_list=var_list) + + base_optimizer_slots = self.base_optimizer.get_slot_names() + + for var in var_list: + for slot_name in base_optimizer_slots: + self.add_slot( + var, + slot_name, + initializer=tf.zeros_like(var), + ) + + # create slots for accumulated gradients + for var in var_list: + self.add_slot(var, "ga", initializer=tf.zeros_like(var)) + + self._gradients = [self.get_slot(var, "ga") for var in var_list] + + @property + def step(self) -> tf.Variable: + """Returns the number of training steps this Optimizer has run.""" + return self._step + + @step.setter + def step(self, variable: tf.Variable): + """Sets the step value.""" + self._step = variable + self._weights.append(self._step) + + @property + def gradients(self) -> list: + """Returns the current accumulated gradients on the replica.""" + tf.debugging.assert_greater( + tf.size(self._gradients), + self._zero, + message="Gradients have not been computed yet. " + "If you're using GradientAccumulateOptimizer with " + "a custom training loop, please make sure to call " + "optimizer.apply_gradients() before accessing " + "optimizer.gradients.", + ) + + empty_grad_tensor = tf.zeros([], dtype=self._gradient.dtype) + return get_gradients(self._gradients, empty_grad_tensor) + + def apply_gradients( + self, grads_and_vars: dict, name: Optional[str] = None, **kwargs + ) -> tf.Operation: + train_op = super().apply_gradients(grads_and_vars, name, **kwargs) + with tf.control_dependencies([train_op]): + with tf.control_dependencies( + [ + self._optimizer.iterations.assign_add( + tf.cast( + tf.equal( + tf.math.mod( + self.step, + self._accum_steps, + ), + self._zero, + ), + tf.int64, + ), + read_value=False, + ) + ] + ): + return self.step.assign_add(1, read_value=False) + + @tf.function + def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): + return agc.adaptive_clip_grad( + [var], [grad], clip_factor=self.clip_factor + )[0] + + @tf.function + def _parse_grad( + self, accum_gradient: tf.Tensor, var: tf.Variable + ) -> tf.Tensor: + """Parses the accumulated gradient and returns the gradient to be + applied.""" + + apply_condition = tf.fill( + tf.shape(accum_gradient), + tf.equal(tf.math.mod(self.step, self._accum_steps), self._zero), + ) + + def apply_agc(): + return self._apply_agc(accum_gradient, var) + + def return_grad(): + return accum_gradient + + return tf.where( + apply_condition, + tf.cond(self._agc, apply_agc, return_grad), + tf.zeros_like(var, dtype=accum_gradient.dtype), + ) + + @tf.function + def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): + return tf.where( + tf.math.equal(grad, accum_gradient), + tf.zeros_like(accum_gradient), + accum_gradient, + ) + + def _resource_apply_dense( + self, + grad: tf.Tensor, + var: tf.Variable, + apply_state: Optional[str] = None, + ): + """ + Performs gradient update on sparse tensor. + + Parameters + ---------- + grad : tensor + Current gradient. + var : tensor + Current variable. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor + apply_op. + + """ + accum_gradient = self.get_slot(var, "ga") + + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + accum_gradient.assign_add( + scaled_grad, use_locking=self._use_locking, read_value=False + ) + + def _apply(accum_gradient, var, apply_state): + grad = self._parse_grad(accum_gradient, var) + + train_op = self.base_optimizer._resource_apply_dense( + grad, + var, + apply_state=apply_state if apply_state else None, + ) + + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) + + return tf.group(train_op, reset_op) + + return _apply(accum_gradient, var, apply_state) + + def _resource_apply_sparse( + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ): + """Performs gradient update on sparse tensor. + + Parameters + ---------- + grad : tensor + The current gradient. + var : tensor + The current variable. + indices : tensor + Relevant indices to be used for masking the sparse tensor during + update. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor + The operation after applying the gradient update. + + """ + + accum_gradient = self.get_slot(var, "ga") + + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) + + def _apply(accum_gradient, var, apply_state): + grad = self._parse_grad(accum_gradient, var) + + train_op = self.base_optimizer._resource_apply_sparse( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state if apply_state else None, + ) + + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) + + return tf.group(train_op, reset_op) + + return _apply(accum_gradient, var, apply_state) + + def _resource_apply_sparse_duplicate_indices( + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ): + """ + Performs gradient update on sparse tensor with duplicate indices. + + Parameters + ---------- + grad : tf.Tensor + Current gradient. + var : tf.Variable + Current variable. + indices : tf.Tensor + Relevant indices to be used for masking the sparse tensor during + update. + apply_state : str, optional + State of the optimizer. Defaults to None. + + """ + accum_gradient = self.get_slot(var, "ga") + + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) + + def _apply(accum_gradient, var, apply_state): + grad = self._parse_grad(accum_gradient, var) + + train_op = ( + self.base_optimizer._resource_apply_sparse_duplicate_indices( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state if apply_state else None, + ) + ) + + reset_op = accum_gradient.assign( + self.reset_accum_gradient(accum_gradient, grad), + use_locking=self._use_locking, + read_value=False, + ) + + return tf.group(train_op, reset_op) + + return _apply(accum_gradient, var, apply_state) + + @tf.function + def _reset_single_gradient(self, gradient: tf.Tensor): + return gradient.assign( + tf.zeros_like(gradient), + use_locking=self._use_locking, + read_value=False, + ) + + def reset(self): + """Resets the accumulated gradients on the current replica.""" + reset_ops = [ + self._reset_single_gradient(gradient) + for gradient in self._gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) ] - ) - - # wrap model to use gradient accumulation - model = GradientAccumulateModel( - accum_steps=4, - mixed_precision=True, - inputs=model.input, - outputs=model.output, - ) - - # need to scale optimizer for mixed precision - opt = tf.keras.optimizers.Adam(1e-3) - opt = mixed_precision.LossScaleOptimizer(opt) - - # compile model - model.compile( - optimizer=opt, - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], - ) - - # train model - model.fit( - ds_train, - epochs=1, - validation_data=ds_test, - ) - - # save model on disk - model.save("./trained_model") - - # load trained model and test - del model - trained_model = tf.keras.models.load_model("./trained_model", compile=True) - - result = trained_model.evaluate(ds_test, verbose=1) - print(result) - - -def test_mixed_precision(): - # launch experiment in separate process, as we are enabling mixed precision - # which will impact other unit tests, unless we do this - try: - from pytest_cov.embed import cleanup_on_sigterm - except ImportError: - pass - else: - cleanup_on_sigterm() - - try: - mp.set_start_method( - "spawn", force=True - ) # set start method to 'spawn' BEFORE instantiating the queue and the event - except RuntimeError: - pass - - p = mp.Process(target=run_experiment) - try: - p.start() - finally: - p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) + return tf.group(*reset_ops) + + @property + def optimizer(self) -> tf.keras.optimizers.Optimizer: + """The optimizer that this AccumOptimizer is wrapping. In the case of mixed + precision, this is the LossScaleOptimizer.""" + return self._optimizer + + @property + def iterations(self) -> tf.Variable: + """Returns current iteration value of optimizer.""" + return self._optimizer.iterations + + @iterations.setter + def iterations(self, variable: tf.Variable): + """Sets the iterations value of optimizer.""" + self._optimizer.iterations = variable + + @property + def lr(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.base_optimizer.learning_rate + + @lr.setter + def lr(self, lr): + """Sets the learning rate of the optimizer.""" + self.base_optimizer.learning_rate = lr + self._learning_rate = lr + + @property + def learning_rate(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.lr + + @learning_rate.setter + def learning_rate(self, learning_rate: float): + """Sets the learning rate of the optimizer.""" + self.lr = learning_rate + + @property + def _learning_rate(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.lr + + def get_config(self) -> dict: + """Returns the configuration as a dictionary.""" + config = super().get_config() + custom_config = { + "optimizer": tf.keras.optimizers.serialize(self._optimizer), + "accum_steps": self.accum_steps, + "reduction": self.reduction, + "agc": self.agc, + "mixed_precision": self.mixed_precision, + "dtype": self.dtype.name, + } + config.update(custom_config) + return config + + @classmethod + def from_config( + cls, config: dict, custom_objects: Optional[str] = None + ) -> object: + """Creates an instance of the optimizer from its config.""" + optimizer_config = config.pop("optimizer") + optimizer = tf.keras.optimizers.deserialize( + optimizer_config, custom_objects=custom_objects + ) + return cls(optimizer=optimizer, **config) From 3264e5e20ab2741f45115c93764b5e391673c4e8 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Tue, 6 Feb 2024 17:27:00 -0500 Subject: [PATCH 15/30] fix: broken import (agc) --- tests/test_mixed_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py index f36f586..3c6b83b 100644 --- a/tests/test_mixed_precision.py +++ b/tests/test_mixed_precision.py @@ -2,7 +2,7 @@ import tensorflow as tf -from . import agc +from gradient_accumulator import agc from .utils import get_gradients # dynamically handle which Optimizer class to use dep on tf version From 472af58178f3366343043ff0444e7be15e8319bc Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Wed, 7 Feb 2024 09:33:53 -0500 Subject: [PATCH 16/30] feat: revert test_mixed_precision --- gradient_accumulator/accumulators.py | 2 + tests/test_mixed_precision.py | 881 ++++++--------------------- 2 files changed, 202 insertions(+), 681 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 11cbcf7..7987c57 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -273,6 +273,8 @@ def __init__( dtype=tf.int64, aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) + if not hasattr(self, "_weights"): + self._weights = [] self._weights.append(self._step) self._zero = tf.constant(0, dtype=tf.int64) self.dtype = dtype diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py index 3c6b83b..75c676b 100644 --- a/tests/test_mixed_precision.py +++ b/tests/test_mixed_precision.py @@ -1,683 +1,202 @@ -from typing import Optional - +import pytest import tensorflow as tf - -from gradient_accumulator import agc -from .utils import get_gradients - -# dynamically handle which Optimizer class to use dep on tf version -opt = tf.keras.optimizers.Optimizer -if int(tf.version.VERSION.split(".")[1]) > 10: - opt = tf.keras.optimizers.legacy.Optimizer - - -# https://stackoverflow.com/a/66524901 -# https://keras.io/guides/customizing_what_happens_in_fit/ # noqa -@tf.keras.utils.register_keras_serializable("gradient-accumulator") -class GradientAccumulateModel(tf.keras.Model): - """Model wrapper for gradient accumulation.""" - - def __init__( - self, - accum_steps: int = 1, - mixed_precision: bool = False, - use_agc: bool = False, - clip_factor: float = 0.01, - eps: float = 1e-3, - experimental_distributed_support: bool = False, - *args, - **kwargs, - ): - """Adds gradient accumulation support to existing Keras Model. - - Args: - accum_steps: int > 0. Update gradient in every accumulation steps. - mixed_precision: bool. Whether to enable mixed precision. - use_agc: bool. Whether to enable adaptive gradient clipping. - clip_factor: float > 0. Upper limit to gradient clipping. - eps: float > 0. Small value to aid numerical stability. - experimental_distributed_support: bool. Whether to enable - experimental multi-gpu support. Only compatible with SGD. Can - be used with other optimizers but we do not have complete - control of the optimizer's state between accum_steps. - **kwargs: keyword arguments. - """ - super().__init__(*args, **kwargs) - self.accum_steps = tf.constant( - accum_steps, dtype=tf.int32, name="accum_steps" - ) - self.accum_step_counter = tf.Variable( - 0, - dtype=tf.int32, - trainable=False, - name="accum_counter", - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - self.first_call = True - self.mixed_precision = mixed_precision - self.use_agc = use_agc - self.clip_factor = clip_factor - self.eps = eps - self.experimental_distributed_support = experimental_distributed_support - self.dtype_value = self.dtype - self.gradient_accumulation = None - self.reinit_grad_accum() - - def train_step(self, data): - """Performs single train step.""" - # need to reinit accumulator for models subclassed from tf.keras.Model - if self.first_call: - self.reinit_grad_accum() - self.first_call = False - - self.accum_step_counter.assign_add(1) - - # Unpack the data. Its structure depends on your model and - # on what you pass to `fit()`. - # NOTE that x and y are lists of inputs and outputs, - # hence this wrapper supports multi-input-output models - if len(data) == 3: - x, y, sample_weight = data - else: - sample_weight = None - x, y = data - - # Gradient Tape - with tf.GradientTape() as tape: - y_pred = self(x, training=True) # forward pass - - # Compute the loss value. - # The loss function is configured in `compile()`. - loss = self.compiled_loss( - y, - y_pred, - sample_weight=sample_weight, - regularization_losses=self.losses, - ) - loss = loss / tf.cast( - self.accum_steps, loss.dtype - ) # MEAN reduction here IMPORTANT! Don't use SUM! - - # scale loss if mixed precision is enabled - if self.mixed_precision: - loss = self.optimizer.get_scaled_loss(loss) - - # Calculate batch gradients -> these are scaled gradients if mixed - # precision is enabled - gradients = tape.gradient( - loss, - self.trainable_variables, - unconnected_gradients=tf.UnconnectedGradients.ZERO, - ) - - # scale gradients if mixed precision is enabled - if self.mixed_precision: - gradients = self.optimizer.get_unscaled_gradients(gradients) - - # apply adaptive gradient clipping -> should be AFTER unscaling - if self.use_agc: - gradients = agc.adaptive_clip_grad( - self.trainable_variables, - gradients, - clip_factor=self.clip_factor, - eps=self.eps, - ) - - # Accumulate batch gradients - for i in range(len(self.gradient_accumulation)): - self.gradient_accumulation[i].assign_add( - gradients[i], read_value=False - ) - - # accumulate gradients only after certain number of steps - # self.accum_steps.assign(self.accum_steps * tf.cast(tf.logical_not(\ - # tf.equal(self.accum_step_counter,self.accum_steps)), tf.int32)) - if not self.experimental_distributed_support: - tf.cond( - tf.equal(self.accum_step_counter, self.accum_steps), - true_fn=self.apply_accu_gradients, - false_fn=lambda: None, - ) - - else: - # NOTE: This enabled multi-gpu support, but only for SGD (!) - should_apply = tf.equal(self.accum_step_counter, self.accum_steps) - logical_grads = [ - tf.cast(should_apply, grad_component.dtype) * grad_component - for grad_component in self.gradient_accumulation - ] - self.optimizer.apply_gradients( - zip(logical_grads, self.trainable_variables) - ) - self.accum_step_counter.assign( - self.accum_step_counter - * tf.cast(tf.logical_not(should_apply), tf.int32) - ) - for i in range(len(self.gradient_accumulation)): - self.gradient_accumulation[i].assign_add(-1 * logical_grads[i]) - - # update metrics - self.compiled_metrics.update_state( - y, y_pred, sample_weight=sample_weight - ) - return {m.name: m.result() for m in self.metrics} - - def apply_accu_gradients(self): - """Performs gradient update and resets slots afterwards.""" - # apply accumulated gradients - self.optimizer.apply_gradients( - zip(self.gradient_accumulation, self.trainable_variables) - ) - - # reset - self.accum_step_counter.assign(0) - for i in range(len(self.gradient_accumulation)): - self.gradient_accumulation[i].assign( - tf.zeros_like( - self.trainable_variables[i], dtype=self.dtype_value - ), - read_value=False, - ) - - def reinit_grad_accum(self): - """Reinitialized gradient accumulator slots.""" - # reinitialize gradient accumulator - self.gradient_accumulation = [ - tf.Variable( - tf.zeros_like(v, dtype=self.dtype_value), - trainable=False, - name="accum_" + str(i), - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - for i, v in enumerate(self.trainable_variables) - ] - - -# Implementation was derived from: -# https://github.com/fsx950223/addons/blob/67c1e8ea19e82c3f2a5706674dd81f15ab5002a2/tensorflow_addons/optimizers/gradient_accumulator.py # noqa -# https://github.com/FreddeFrallan/Multilingual-CLIP/blob/5c82118452b3b59b41bb53714d61cd4990b1588d/multilingual_clip/TeacherLearning/Utils.py#L84 # noqa -@tf.keras.utils.register_keras_serializable("gradient-accumulator") -class GradientAccumulateOptimizer(opt): - """Optimizer wrapper for gradient accumulation.""" - - def __init__( - self, - optimizer: str = "SGD", - accum_steps: int = 1, - reduction: str = "MEAN", - agc: bool = False, - mixed_precision: bool = False, - name: str = "GradientAccumulateOptimizer", - dtype: tf.dtypes.DType = tf.float32, - **kwargs, - ): - """ - Construct a new GradientAccumulateOptimizer optimizer. - - Parameters - ---------- - optimizer : str or tf.keras.optimizers.Optimizer - Optimizer that will be used to compute and apply gradients. - accum_steps : int, optional - Update gradient in every accumulation steps, must be > 0. - reduction : str, optional - Gradient reduction method to use. Can be 'MEAN' or 'SUM'. - agc : bool, optional - Whether to use adaptive gradient clipping. Defaults to False. - mixed_precision : bool, optional - Whether to use mixed precision. Defaults to False. - name : str, optional - Name for the operations created when applying gradients. Defaults to - "GradientAccumulateOptimizer". - **kwargs : dict - Additional keyword arguments. Allowed keys are: - - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. - - `lr`: Learning rate, included for backward compatibility. Use - `learning_rate` instead. - - Notes - ----- - Adding support for sparse tensors was tricky. For correct implementation, both - `_resource_apply_sparse()` - and `_resource_apply_sparse_duplicate_indices()` methods need to be implemented. - - References - ---------- - .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa - - """ - super().__init__(name, **kwargs) - optimizer = tf.keras.optimizers.get(optimizer) - self._optimizer = ( - tf.keras.mixed_precision.LossScaleOptimizer( - optimizer - ) - if mixed_precision and not isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer) - else optimizer - ) - self.base_optimizer = ( - self._optimizer.inner_optimizer - if mixed_precision - else self._optimizer - ) - self.mixed_precision = mixed_precision - self._mixed_precision = tf.constant(mixed_precision, dtype=tf.bool) - self.accum_steps = accum_steps - self._accum_steps = tf.constant(accum_steps, dtype=tf.int64) - self.reduction = reduction - self._reduction = tf.constant(reduction, dtype=tf.string) - self._step = tf.Variable( - initial_value=1, - trainable=False, - dtype=tf.int64, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - self._weights.append(self._step) - self._zero = tf.constant(0, dtype=tf.int64) - self.dtype = dtype - self.agc = agc - self._agc = tf.constant(agc) - if agc: - if "clip_factor" in kwargs: - self.clip_factor = tf.constant( - kwargs.pop("clip_factor"), dtype=tf.float32 - ) - else: - self.clip_factor = tf.constant(0.01, dtype=tf.float32) - else: - self.clip_factor = tf.constant(0.0, dtype=tf.float32) - - def get_slot(self, *args, **kwargs): - return self._optimizer.get_slot(*args, **kwargs) - - def add_slot(self, *args, **kwargs): - return self._optimizer.add_slot(*args, **kwargs) - - def _create_slots(self, var_list: list): - # create slots using the base optimizer - self.base_optimizer._create_slots(var_list=var_list) - - base_optimizer_slots = self.base_optimizer.get_slot_names() - - for var in var_list: - for slot_name in base_optimizer_slots: - self.add_slot( - var, - slot_name, - initializer=tf.zeros_like(var), - ) - - # create slots for accumulated gradients - for var in var_list: - self.add_slot(var, "ga", initializer=tf.zeros_like(var)) - - self._gradients = [self.get_slot(var, "ga") for var in var_list] - - @property - def step(self) -> tf.Variable: - """Returns the number of training steps this Optimizer has run.""" - return self._step - - @step.setter - def step(self, variable: tf.Variable): - """Sets the step value.""" - self._step = variable - self._weights.append(self._step) - - @property - def gradients(self) -> list: - """Returns the current accumulated gradients on the replica.""" - tf.debugging.assert_greater( - tf.size(self._gradients), - self._zero, - message="Gradients have not been computed yet. " - "If you're using GradientAccumulateOptimizer with " - "a custom training loop, please make sure to call " - "optimizer.apply_gradients() before accessing " - "optimizer.gradients.", - ) - - empty_grad_tensor = tf.zeros([], dtype=self._gradient.dtype) - return get_gradients(self._gradients, empty_grad_tensor) - - def apply_gradients( - self, grads_and_vars: dict, name: Optional[str] = None, **kwargs - ) -> tf.Operation: - train_op = super().apply_gradients(grads_and_vars, name, **kwargs) - with tf.control_dependencies([train_op]): - with tf.control_dependencies( - [ - self._optimizer.iterations.assign_add( - tf.cast( - tf.equal( - tf.math.mod( - self.step, - self._accum_steps, - ), - self._zero, - ), - tf.int64, - ), - read_value=False, - ) - ] - ): - return self.step.assign_add(1, read_value=False) - - @tf.function - def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): - return agc.adaptive_clip_grad( - [var], [grad], clip_factor=self.clip_factor - )[0] - - @tf.function - def _parse_grad( - self, accum_gradient: tf.Tensor, var: tf.Variable - ) -> tf.Tensor: - """Parses the accumulated gradient and returns the gradient to be - applied.""" - - apply_condition = tf.fill( - tf.shape(accum_gradient), - tf.equal(tf.math.mod(self.step, self._accum_steps), self._zero), - ) - - def apply_agc(): - return self._apply_agc(accum_gradient, var) - - def return_grad(): - return accum_gradient - - return tf.where( - apply_condition, - tf.cond(self._agc, apply_agc, return_grad), - tf.zeros_like(var, dtype=accum_gradient.dtype), - ) - - @tf.function - def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): - return tf.where( - tf.math.equal(grad, accum_gradient), - tf.zeros_like(accum_gradient), - accum_gradient, - ) - - def _resource_apply_dense( - self, - grad: tf.Tensor, - var: tf.Variable, - apply_state: Optional[str] = None, - ): - """ - Performs gradient update on sparse tensor. - - Parameters - ---------- - grad : tensor - Current gradient. - var : tensor - Current variable. - apply_state : str, optional - State of the optimizer. Defaults to None. - - Returns - ------- - tensor - apply_op. - - """ - accum_gradient = self.get_slot(var, "ga") - - # undo loss scaling and revert to higher precision - grad_to_use = ( - self._optimizer.get_unscaled_gradients([grad])[0] - if self.mixed_precision - else grad - ) - - # scale down the gradient and add it to the accumulated gradient - scaled_grad = tf.math.divide_no_nan( - grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) - ) - - accum_gradient.assign_add( - scaled_grad, use_locking=self._use_locking, read_value=False - ) - - def _apply(accum_gradient, var, apply_state): - grad = self._parse_grad(accum_gradient, var) - - train_op = self.base_optimizer._resource_apply_dense( - grad, - var, - apply_state=apply_state if apply_state else None, - ) - - reset_op = accum_gradient.assign( - self.reset_accum_gradient(accum_gradient, grad), - use_locking=self._use_locking, - read_value=False, - ) - - return tf.group(train_op, reset_op) - - return _apply(accum_gradient, var, apply_state) - - def _resource_apply_sparse( - self, - grad: tf.Tensor, - var: tf.Variable, - indices: tf.Tensor, - apply_state: Optional[str] = None, - ): - """Performs gradient update on sparse tensor. - - Parameters - ---------- - grad : tensor - The current gradient. - var : tensor - The current variable. - indices : tensor - Relevant indices to be used for masking the sparse tensor during - update. - apply_state : str, optional - State of the optimizer. Defaults to None. - - Returns - ------- - tensor - The operation after applying the gradient update. - - """ - - accum_gradient = self.get_slot(var, "ga") - - # undo loss scaling and revert to higher precision - grad_to_use = ( - self._optimizer.get_unscaled_gradients([grad])[0] - if self.mixed_precision - else grad - ) - - # scale down the gradient and add it to the accumulated gradient - scaled_grad = tf.math.divide_no_nan( - grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) - ) - - self._resource_scatter_add( - accum_gradient, - indices, - scaled_grad, - ) - - def _apply(accum_gradient, var, apply_state): - grad = self._parse_grad(accum_gradient, var) - - train_op = self.base_optimizer._resource_apply_sparse( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state if apply_state else None, - ) - - reset_op = accum_gradient.assign( - self.reset_accum_gradient(accum_gradient, grad), - use_locking=self._use_locking, - read_value=False, - ) - - return tf.group(train_op, reset_op) - - return _apply(accum_gradient, var, apply_state) - - def _resource_apply_sparse_duplicate_indices( - self, - grad: tf.Tensor, - var: tf.Variable, - indices: tf.Tensor, - apply_state: Optional[str] = None, - ): - """ - Performs gradient update on sparse tensor with duplicate indices. - - Parameters - ---------- - grad : tf.Tensor - Current gradient. - var : tf.Variable - Current variable. - indices : tf.Tensor - Relevant indices to be used for masking the sparse tensor during - update. - apply_state : str, optional - State of the optimizer. Defaults to None. - - """ - accum_gradient = self.get_slot(var, "ga") - - # undo loss scaling and revert to higher precision - grad_to_use = ( - self._optimizer.get_unscaled_gradients([grad])[0] - if self.mixed_precision - else grad - ) - - # scale down the gradient and add it to the accumulated gradient - scaled_grad = tf.math.divide_no_nan( - grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) - ) - - self._resource_scatter_add( - accum_gradient, - indices, - scaled_grad, - ) - - def _apply(accum_gradient, var, apply_state): - grad = self._parse_grad(accum_gradient, var) - - train_op = ( - self.base_optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state if apply_state else None, - ) - ) - - reset_op = accum_gradient.assign( - self.reset_accum_gradient(accum_gradient, grad), - use_locking=self._use_locking, - read_value=False, - ) - - return tf.group(train_op, reset_op) - - return _apply(accum_gradient, var, apply_state) - - @tf.function - def _reset_single_gradient(self, gradient: tf.Tensor): - return gradient.assign( - tf.zeros_like(gradient), - use_locking=self._use_locking, - read_value=False, - ) - - def reset(self): - """Resets the accumulated gradients on the current replica.""" - reset_ops = [ - self._reset_single_gradient(gradient) - for gradient in self._gradients - if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) +import multiprocessing as mp +from .utils import reset, get_opt + + +tf_version = int(tf.version.VERSION.split(".")[1]) + +@pytest.fixture +def generate_experiment_prerequisites(): + import os + + import tensorflow as tf + import tensorflow_datasets as tfds + + from .utils import normalize_img + + # disable GPU + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + # load dataset + (ds_train, ds_test), ds_info = tfds.load( + "mnist", + split=["train", "test"], + shuffle_files=True, + as_supervised=True, + with_info=True, + ) + + # build train pipeline + ds_train = ds_train.map(normalize_img) + ds_train = ds_train.cache() + ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples) + ds_train = ds_train.batch( + 32 + ) # multiplum of 8 on GPU to maximize performance + ds_train = ds_train.prefetch(1) + + # build test pipeline + ds_test = ds_test.map(normalize_img) + ds_test = ds_test.batch(32) + ds_test = ds_test.cache() + ds_test = ds_test.prefetch(1) + + # create model + model = tf.keras.models.Sequential( + [ + tf.keras.layers.Flatten(input_shape=(28, 28)), + tf.keras.layers.Dense(32, activation="relu"), # 32 multiplum of 8 + tf.keras.layers.Dense( + 10, dtype="float32" + ), # output not numerically stable with float16 ] - return tf.group(*reset_ops) - - @property - def optimizer(self) -> tf.keras.optimizers.Optimizer: - """The optimizer that this AccumOptimizer is wrapping. In the case of mixed - precision, this is the LossScaleOptimizer.""" - return self._optimizer - - @property - def iterations(self) -> tf.Variable: - """Returns current iteration value of optimizer.""" - return self._optimizer.iterations - - @iterations.setter - def iterations(self, variable: tf.Variable): - """Sets the iterations value of optimizer.""" - self._optimizer.iterations = variable - - @property - def lr(self) -> float: - """Returns the learning rate of the optimizer.""" - return self.base_optimizer.learning_rate - - @lr.setter - def lr(self, lr): - """Sets the learning rate of the optimizer.""" - self.base_optimizer.learning_rate = lr - self._learning_rate = lr - - @property - def learning_rate(self) -> float: - """Returns the learning rate of the optimizer.""" - return self.lr - - @learning_rate.setter - def learning_rate(self, learning_rate: float): - """Sets the learning rate of the optimizer.""" - self.lr = learning_rate - - @property - def _learning_rate(self) -> float: - """Returns the learning rate of the optimizer.""" - return self.lr - - def get_config(self) -> dict: - """Returns the configuration as a dictionary.""" - config = super().get_config() - custom_config = { - "optimizer": tf.keras.optimizers.serialize(self._optimizer), - "accum_steps": self.accum_steps, - "reduction": self.reduction, - "agc": self.agc, - "mixed_precision": self.mixed_precision, - "dtype": self.dtype.name, - } - config.update(custom_config) - return config - - @classmethod - def from_config( - cls, config: dict, custom_objects: Optional[str] = None - ) -> object: - """Creates an instance of the optimizer from its config.""" - optimizer_config = config.pop("optimizer") - optimizer = tf.keras.optimizers.deserialize( - optimizer_config, custom_objects=custom_objects - ) - return cls(optimizer=optimizer, **config) + ) + return model, ds_train, ds_test + + +def run_experiment_model(generate_experiment_prerequisites): + import tensorflow as tf + from tensorflow.keras import mixed_precision + + from gradient_accumulator import GradientAccumulateModel + + # set mixed global precision policy + mixed_precision.set_global_policy("mixed_float16") + + model, ds_train, ds_test = generate_experiment_prerequisites + + # wrap model to use gradient accumulation + model = GradientAccumulateModel( + accum_steps=4, + mixed_precision=True, + inputs=model.input, + outputs=model.output, + ) + + # need to scale optimizer for mixed precision + opt = tf.keras.optimizers.Adam(1e-3) + opt = mixed_precision.LossScaleOptimizer(opt) + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit( + ds_train, + epochs=1, + validation_data=ds_test, + ) + + # save model on disk + model.save("./trained_model") + + # load trained model and test + del model + trained_model = tf.keras.models.load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + +def run_experiment_optimizer(generate_experiment_prerequisites): + import tensorflow as tf + from tensorflow.keras import mixed_precision + + from gradient_accumulator import GradientAccumulateOptimizer + + # set mixed global precision policy + mixed_precision.set_global_policy("mixed_float16") + + model, ds_train, ds_test = generate_experiment_prerequisites + + # wrap model to use gradient accumulation + model = tf.keras.Model(inputs=model.input, outputs=model.output) + + opt = get_opt(opt_name="adam", tf_version=tf_version) + + # need to scale optimizer for mixed precision + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=True) + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit( + ds_train, + epochs=1, + validation_data=ds_test, + ) + + # save model on disk + model.save("./trained_model") + + # load trained model and test + del model + trained_model = tf.keras.models.load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + +def test_mixed_precision(): + # set seed + reset() + + # Model with mixed precision + + # launch experiment in separate process, as we are enabling mixed precision + # which will impact other unit tests, unless we do this + try: + from pytest_cov.embed import cleanup_on_sigterm + except ImportError: + pass + else: + cleanup_on_sigterm() + + try: + mp.set_start_method( + "spawn", force=True + ) # set start method to 'spawn' BEFORE instantiating the queue and the event + except RuntimeError: + pass + + p = mp.Process(target=run_experiment_model) + try: + p.start() + finally: + p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) + + reset() + + # Optimizer with mixed precision + + # launch experiment in separate process, as we are enabling mixed precision + # which will impact other unit tests, unless we do this + try: + from pytest_cov.embed import cleanup_on_sigterm + except ImportError: + pass + else: + cleanup_on_sigterm() + + try: + mp.set_start_method( + "spawn", force=True + ) # set start method to 'spawn' BEFORE instantiating the queue and the event + except RuntimeError: + pass + + p = mp.Process(target=run_experiment_optimizer) + try: + p.start() + finally: + p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) \ No newline at end of file From fa1ef2e033b7e8d40c69cc92689207cd1cada53e Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Wed, 7 Feb 2024 22:41:45 -0500 Subject: [PATCH 17/30] feat: increase test coverage --- gradient_accumulator/accumulators.py | 72 ++++++++++++++++-------- tests/test_adaptive_gradient_clipping.py | 61 +++++++++++++++++--- tests/test_mixed_precision.py | 12 ++-- tests/test_optimizer_properties.py | 52 +++++++++++++++++ 4 files changed, 160 insertions(+), 37 deletions(-) create mode 100644 tests/test_optimizer_properties.py diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 7987c57..4c5262f 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -207,7 +207,7 @@ def __init__( optimizer: str = "SGD", accum_steps: int = 1, reduction: str = "MEAN", - agc: bool = False, + use_agc: bool = False, mixed_precision: bool = False, name: str = "GradientAccumulateOptimizer", dtype: tf.dtypes.DType = tf.float32, @@ -224,7 +224,7 @@ def __init__( Update gradient in every accumulation steps, must be > 0. reduction : str, optional Gradient reduction method to use. Can be 'MEAN' or 'SUM'. - agc : bool, optional + use_agc : bool, optional Whether to use adaptive gradient clipping. Defaults to False. mixed_precision : bool, optional Whether to use mixed precision. Defaults to False. @@ -249,12 +249,14 @@ def __init__( """ super().__init__(name, **kwargs) + optimizer = tf.keras.optimizers.get(optimizer) self._optimizer = ( - tf.keras.mixed_precision.LossScaleOptimizer( - tf.keras.optimizers.get(optimizer) - ) + tf.keras.mixed_precision.LossScaleOptimizer(optimizer) if mixed_precision - else tf.keras.optimizers.get(optimizer) + and not isinstance( + optimizer, tf.keras.mixed_precision.LossScaleOptimizer + ) + else optimizer ) self.base_optimizer = ( self._optimizer.inner_optimizer @@ -275,12 +277,14 @@ def __init__( ) if not hasattr(self, "_weights"): self._weights = [] + if not hasattr(self, "_gradients"): + self._gradients = [] self._weights.append(self._step) self._zero = tf.constant(0, dtype=tf.int64) self.dtype = dtype - self.agc = agc - self._agc = tf.constant(agc) - if agc: + self.use_agc = use_agc + self._use_agc = tf.constant(use_agc) + if use_agc: if "clip_factor" in kwargs: self.clip_factor = tf.constant( kwargs.pop("clip_factor"), dtype=tf.float32 @@ -291,12 +295,17 @@ def __init__( self.clip_factor = tf.constant(0.0, dtype=tf.float32) def get_slot(self, *args, **kwargs): + """Returns a slot created by the optimizer.""" return self._optimizer.get_slot(*args, **kwargs) - def add_slot(self, *args, **kwargs): - return self._optimizer.add_slot(*args, **kwargs) + def add_slot(self, var, slot_name, initializer): + """Adds a new slot to the optimizer.""" + slot = self._optimizer.add_slot(var, slot_name, initializer=initializer) + self._gradients.append(slot) + return slot def _create_slots(self, var_list: list): + """Creates slots for the optimizer.""" # create slots using the base optimizer self.base_optimizer._create_slots(var_list=var_list) @@ -346,6 +355,21 @@ def gradients(self) -> list: def apply_gradients( self, grads_and_vars: dict, name: Optional[str] = None, **kwargs ) -> tf.Operation: + """Applies gradients to variables and updates the optimizer's state. + + Parameters + ---------- + grads_and_vars : dict + A dictionary of {gradient: variable} pairs. + name : Optional[str], optional + The name for the operation. Defaults to None. + + Returns + ------- + tf.Operation + The operation after applying the gradients. + + """ train_op = super().apply_gradients(grads_and_vars, name, **kwargs) with tf.control_dependencies([train_op]): with tf.control_dependencies( @@ -368,7 +392,8 @@ def apply_gradients( return self.step.assign_add(1, read_value=False) @tf.function - def _apply_agc(self, grad: tf.Tensor, var: tf.Variable): + def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: + """Applies adaptive gradient clipping to the gradient.""" return agc.adaptive_clip_grad( [var], [grad], clip_factor=self.clip_factor )[0] @@ -393,12 +418,13 @@ def return_grad(): return tf.where( apply_condition, - tf.cond(self._agc, apply_agc, return_grad), + tf.cond(self._use_agc, apply_agc, return_grad), tf.zeros_like(var, dtype=accum_gradient.dtype), ) @tf.function def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): + """Resets the accumulated gradient to zero after applying.""" return tf.where( tf.math.equal(grad, accum_gradient), tf.zeros_like(accum_gradient), @@ -410,7 +436,7 @@ def _resource_apply_dense( grad: tf.Tensor, var: tf.Variable, apply_state: Optional[str] = None, - ): + ) -> tf.Operation: """ Performs gradient update on sparse tensor. @@ -472,7 +498,7 @@ def _resource_apply_sparse( var: tf.Variable, indices: tf.Tensor, apply_state: Optional[str] = None, - ): + ) -> tf.Operation: """Performs gradient update on sparse tensor. Parameters @@ -540,7 +566,7 @@ def _resource_apply_sparse_duplicate_indices( var: tf.Variable, indices: tf.Tensor, apply_state: Optional[str] = None, - ): + ) -> tf.Operation: """ Performs gradient update on sparse tensor with duplicate indices. @@ -599,8 +625,8 @@ def _apply(accum_gradient, var, apply_state): return _apply(accum_gradient, var, apply_state) - @tf.function def _reset_single_gradient(self, gradient: tf.Tensor): + """Resets the accumulated gradient on the current replica.""" return gradient.assign( tf.zeros_like(gradient), use_locking=self._use_locking, @@ -644,14 +670,12 @@ def lr(self, lr): self._learning_rate = lr @property - def learning_rate(self) -> float: - """Returns the learning rate of the optimizer.""" - return self.lr + def learning_rate(self): + return self.base_optimizer.learning_rate @learning_rate.setter - def learning_rate(self, learning_rate: float): - """Sets the learning rate of the optimizer.""" - self.lr = learning_rate + def learning_rate(self, lr): + self.base_optimizer.learning_rate = lr @property def _learning_rate(self) -> float: @@ -665,7 +689,7 @@ def get_config(self) -> dict: "optimizer": tf.keras.optimizers.serialize(self._optimizer), "accum_steps": self.accum_steps, "reduction": self.reduction, - "agc": self.agc, + "use_agc": self.use_agc, "mixed_precision": self.mixed_precision, "dtype": self.dtype.name, } diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index d68a8ea..42a290d 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -1,15 +1,16 @@ -import os - +import pytest import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras import mixed_precision from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer from gradient_accumulator import unitwise_norm -from .utils import normalize_img +from .utils import normalize_img, get_opt + +tf_version = int(tf.version.VERSION.split(".")[1]) def test_unitwise_norm(): for i in range(7): @@ -27,7 +28,8 @@ def test_unitwise_norm(): raise e -def test_train_mnist(): +@pytest.fixture +def generate_experiment_prerequisites(): # load dataset (ds_train, ds_test), ds_info = tfds.load( "mnist", @@ -63,7 +65,14 @@ def test_train_mnist(): ), # output not numerically stable with float16 ] ) + return model, ds_train, ds_test + + +def test_train_mnist_model(generate_experiment_prerequisites): + + model, ds_train, ds_test = generate_experiment_prerequisites + # Test AGC with model # wrap model to use gradient accumulation model = GradientAccumulateModel( accum_steps=4, @@ -74,7 +83,44 @@ def test_train_mnist(): ) # need to scale optimizer for mixed precision - opt = tf.keras.optimizers.SGD(1e-2) + opt = get_opt(opt_name="SGD", tf_version=tf_version) + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit( + ds_train, + epochs=1, + validation_data=ds_test, + ) + + model.save("./trained_model") + + # load trained model and test + del model + trained_model = load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + +def test_train_mnist_optimizer(generate_experiment_prerequisites): + + model, ds_train, ds_test = generate_experiment_prerequisites + + + # wrap model to use gradient accumulation + model = tf.keras.Model(inputs=model.input, outputs=model.output) + + opt = get_opt(opt_name="SGD", tf_version=tf_version) + + # need to scale optimizer for mixed precision + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True) # compile model model.compile( @@ -102,4 +148,5 @@ def test_train_mnist(): # for running locally, outside pytest if __name__ == "__main__": - test_train_mnist() + test_train_mnist_model() + test_train_mnist_optimizer() diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py index 75c676b..68056ef 100644 --- a/tests/test_mixed_precision.py +++ b/tests/test_mixed_precision.py @@ -53,7 +53,7 @@ def generate_experiment_prerequisites(): ] ) return model, ds_train, ds_test - + def run_experiment_model(generate_experiment_prerequisites): import tensorflow as tf @@ -63,7 +63,7 @@ def run_experiment_model(generate_experiment_prerequisites): # set mixed global precision policy mixed_precision.set_global_policy("mixed_float16") - + model, ds_train, ds_test = generate_experiment_prerequisites # wrap model to use gradient accumulation @@ -111,14 +111,14 @@ def run_experiment_optimizer(generate_experiment_prerequisites): # set mixed global precision policy mixed_precision.set_global_policy("mixed_float16") - + model, ds_train, ds_test = generate_experiment_prerequisites # wrap model to use gradient accumulation model = tf.keras.Model(inputs=model.input, outputs=model.output) opt = get_opt(opt_name="adam", tf_version=tf_version) - + # need to scale optimizer for mixed precision opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=True) @@ -152,7 +152,7 @@ def test_mixed_precision(): reset() # Model with mixed precision - + # launch experiment in separate process, as we are enabling mixed precision # which will impact other unit tests, unless we do this try: @@ -176,7 +176,7 @@ def test_mixed_precision(): p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) reset() - + # Optimizer with mixed precision # launch experiment in separate process, as we are enabling mixed precision diff --git a/tests/test_optimizer_properties.py b/tests/test_optimizer_properties.py new file mode 100644 index 0000000..1675332 --- /dev/null +++ b/tests/test_optimizer_properties.py @@ -0,0 +1,52 @@ +import pytest +import tensorflow as tf +from gradient_accumulator import GradientAccumulateOptimizer +from .utils import get_opt + +tf_version = int(tf.version.VERSION.split(".")[1]) + + +@pytest.fixture +def optimizer(): + opt = get_opt(opt_name="SGD", tf_version=tf_version) + return GradientAccumulateOptimizer(optimizer=opt, accum_steps=2) + +def test_learning_rate_getter(optimizer): + assert optimizer.learning_rate == 0.01 + +def test_learning_rate_setter(optimizer): + optimizer.learning_rate = 0.02 + assert optimizer.learning_rate == 0.02 + +def test_lr_getter(optimizer): + assert optimizer.lr == 0.01 + +def test_lr_setter(optimizer): + optimizer.lr = 0.02 + assert optimizer.lr == 0.02 + +def test__learning_rate(optimizer): + assert optimizer._learning_rate == 0.01 + optimizer.learning_rate = 0.02 + assert optimizer._learning_rate == 0.02 + +def test_reset_single_gradient(optimizer): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + optimizer.add_slot(var, "ga", initializer=tf.constant([3.0, 4.0])) + gradient = optimizer.get_slot(var, "ga") + optimizer._reset_single_gradient(gradient) + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))) + +def test_reset(optimizer): + var1 = tf.Variable([1.0, 2.0], dtype=tf.float32) + var2 = tf.Variable([3.0, 4.0], dtype=tf.float32) + optimizer.add_slot(var1, "ga", initializer=tf.constant([5.0, 6.0])) + optimizer.add_slot(var2, "ga", initializer=tf.constant([7.0, 8.0])) + for var in [var1, var2]: + gradient = optimizer.get_slot(var, "ga") + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == False + + optimizer.reset() + for var in [var1, var2]: + gradient = optimizer.get_slot(var, "ga") + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == True \ No newline at end of file From 657b4f05da02876c587248f8bd5f3c38b15665ac Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Wed, 7 Feb 2024 23:30:55 -0500 Subject: [PATCH 18/30] feat: cover parse_grad --- ...imizer_properties.py => test_optimizer.py} | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) rename tests/{test_optimizer_properties.py => test_optimizer.py} (58%) diff --git a/tests/test_optimizer_properties.py b/tests/test_optimizer.py similarity index 58% rename from tests/test_optimizer_properties.py rename to tests/test_optimizer.py index 1675332..b521f66 100644 --- a/tests/test_optimizer_properties.py +++ b/tests/test_optimizer.py @@ -30,6 +30,12 @@ def test__learning_rate(optimizer): optimizer.learning_rate = 0.02 assert optimizer._learning_rate == 0.02 +def test_step(optimizer): + assert optimizer.step == 1 + +def test_optimizer_prop(optimizer): + assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__ + def test_reset_single_gradient(optimizer): var = tf.Variable([1.0, 2.0], dtype=tf.float32) optimizer.add_slot(var, "ga", initializer=tf.constant([3.0, 4.0])) @@ -49,4 +55,35 @@ def test_reset(optimizer): optimizer.reset() for var in [var1, var2]: gradient = optimizer.get_slot(var, "ga") - assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == True \ No newline at end of file + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == True + + +@pytest.mark.parametrize("accum_steps", [1, 2, 3]) +@pytest.mark.parametrize("use_agc", [True, False]) +def test_parse_grad(optimizer, use_agc, accum_steps): + if accum_steps == 1: + expected_grad = tf.zeros_like(var) # gradients should not be applied yet + else: + expected_grad = tf.constant([3.0, 4.0]) + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + optimizer.add_slot(var, "ga", initializer=expected_grad) + accum_gradient = optimizer.get_slot(var, "ga") + + optimizer.use_agc = use_agc + optimizer.step.assign(accum_steps) + + parsed_grad = optimizer._parse_grad(accum_gradient, var) + assert tf.reduce_all(tf.equal(parsed_grad, expected_grad)) + + +if __name__ == "__main__": + test__learning_rate(optimizer()) + test_learning_rate_getter(optimizer()) + test_learning_rate_setter(optimizer()) + test_lr_getter(optimizer()) + test_lr_setter(optimizer()) + test_step(optimizer()) + test_optimizer_prop(optimizer()) + test_reset_single_gradient(optimizer()) + test_reset(optimizer()) + test_parse_grad(optimizer()) \ No newline at end of file From e382e16a3d171a104cf6920b0e5c24aa05b9c72c Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Thu, 8 Feb 2024 00:03:58 -0500 Subject: [PATCH 19/30] fix: missing var def --- tests/test_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index b521f66..a52fd2e 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -61,11 +61,11 @@ def test_reset(optimizer): @pytest.mark.parametrize("accum_steps", [1, 2, 3]) @pytest.mark.parametrize("use_agc", [True, False]) def test_parse_grad(optimizer, use_agc, accum_steps): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) if accum_steps == 1: expected_grad = tf.zeros_like(var) # gradients should not be applied yet else: expected_grad = tf.constant([3.0, 4.0]) - var = tf.Variable([1.0, 2.0], dtype=tf.float32) optimizer.add_slot(var, "ga", initializer=expected_grad) accum_gradient = optimizer.get_slot(var, "ga") From 4263d203d6f9146a3573ef39bdb6e6c4eb2f376d Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Thu, 8 Feb 2024 00:06:00 -0500 Subject: [PATCH 20/30] chore: update codecov.yml --- .github/workflows/codecov.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 46da364..0320f8d 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -71,7 +71,8 @@ jobs: --cov=gradient_accumulator tests/test_mp_batch_norm.py \ --cov=gradient_accumulator tests/test_bn_convnd.py \ --cov=gradient_accumulator tests/test_bn_pretrained_swap.py \ - --cov=gradient_accumulator tests/test_model_distribute.py + --cov=gradient_accumulator tests/test_model_distribute.py \ + --cov=gradient_accumulator tests/test_optimizer.py - name: Lint with flake8 run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics From 13111b2246fb733728eeeeb30d7bea5d1831917c Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Thu, 8 Feb 2024 00:24:37 -0500 Subject: [PATCH 21/30] fix: failing unit test --- tests/test_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index a52fd2e..44abf6e 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -58,7 +58,7 @@ def test_reset(optimizer): assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == True -@pytest.mark.parametrize("accum_steps", [1, 2, 3]) +@pytest.mark.parametrize("accum_steps", [1, 2]) @pytest.mark.parametrize("use_agc", [True, False]) def test_parse_grad(optimizer, use_agc, accum_steps): var = tf.Variable([1.0, 2.0], dtype=tf.float32) @@ -73,7 +73,7 @@ def test_parse_grad(optimizer, use_agc, accum_steps): optimizer.step.assign(accum_steps) parsed_grad = optimizer._parse_grad(accum_gradient, var) - assert tf.reduce_all(tf.equal(parsed_grad, expected_grad)) + assert tf.reduce_all(tf.equal(parsed_grad, expected_grad)).numpy() == True if __name__ == "__main__": From edc622121cc4cfa9433ea6411ed3a4cfeb59b360 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 21:33:30 -0500 Subject: [PATCH 22/30] test: update test_optimizer with additional coverage --- tests/test_optimizer.py | 75 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 44abf6e..752d2ca 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -5,6 +5,7 @@ tf_version = int(tf.version.VERSION.split(".")[1]) +tf.config.run_functions_eagerly(True) @pytest.fixture def optimizer(): @@ -30,9 +31,17 @@ def test__learning_rate(optimizer): optimizer.learning_rate = 0.02 assert optimizer._learning_rate == 0.02 -def test_step(optimizer): +def test_step_getter(optimizer): assert optimizer.step == 1 +def test_step_setter(optimizer): + optimizer.step = 1 + assert optimizer.step == 1 + +def test_iterations_setter(optimizer): + optimizer.iterations = 1 + assert optimizer.iterations == 1 + def test_optimizer_prop(optimizer): assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__ @@ -76,14 +85,74 @@ def test_parse_grad(optimizer, use_agc, accum_steps): assert tf.reduce_all(tf.equal(parsed_grad, expected_grad)).numpy() == True +@pytest.fixture +def optimizer_with_grads(optimizer): + opt = optimizer + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + + opt.add_slot(var, "ga", initializer=tf.constant([3.0, 4.0])) + + return opt, var + +def test_reset_accum_gradient_condition(optimizer_with_grads): + optimizer, var = optimizer_with_grads + + accum_grad = optimizer.get_slot(var, "ga") + accum_grad.assign(tf.constant([3.0, 4.0], dtype=tf.float32)) + + current_grad = tf.constant([3.0, 4.0], dtype=tf.float32) + + result_grad = optimizer.reset_accum_gradient(accum_grad, current_grad) + + expected_grad = tf.zeros_like(accum_grad) + + tf.debugging.assert_equal(result_grad, expected_grad, message="Gradients should be reset to zeros") + +@pytest.fixture +def optimizer_adam(): + opt = get_opt(opt_name="adam", tf_version=tf_version) + return GradientAccumulateOptimizer(optimizer=opt, accum_steps=2) + +@pytest.fixture +def optimizer_with_sparse_grads(optimizer_adam): + opt = optimizer_adam + var = tf.Variable(tf.zeros([10, 10]), dtype=tf.float32) + + opt.add_slot(var, "ga", initializer=tf.zeros_like(var)) + opt.add_slot(var, "m", initializer=tf.zeros_like(var)) + opt.add_slot(var, "v", initializer=tf.zeros_like(var)) + + return opt, var + +def test_resource_apply_sparse(optimizer_with_sparse_grads): + optimizer, var = optimizer_with_sparse_grads + + indices = tf.constant([0, 1], dtype=tf.int64) + updates = tf.constant([[0.1] * 10, [0.2] * 10], dtype=tf.float32) + + optimizer._reset_single_gradient(optimizer.get_slot(var, "ga")) + + grad = tf.IndexedSlices(values=updates, indices=indices, dense_shape=var.shape) + + optimizer._resource_apply_sparse(grad.values, var, grad.indices) + optimizer._resource_apply_sparse(grad.values, var, grad.indices) + + accumulated_grads = optimizer.get_slot(var, "ga") + expected_accumulated_grads = tf.scatter_nd(tf.expand_dims(indices, 1), updates * 2, var.shape) / optimizer.accum_steps + tf.debugging.assert_near(accumulated_grads, expected_accumulated_grads, atol=1e-5) + + if __name__ == "__main__": test__learning_rate(optimizer()) test_learning_rate_getter(optimizer()) test_learning_rate_setter(optimizer()) test_lr_getter(optimizer()) test_lr_setter(optimizer()) - test_step(optimizer()) + test_step_getter(optimizer()) + test_step_setter(optimizer()) test_optimizer_prop(optimizer()) test_reset_single_gradient(optimizer()) test_reset(optimizer()) - test_parse_grad(optimizer()) \ No newline at end of file + test_parse_grad(optimizer()) + test_reset_accum_gradient_condition(optimizer_with_grads()) + test_resource_apply_sparse(optimizer_with_sparse_grads()) \ No newline at end of file From 8407d170295ab1f25886a1e9ba85fbd5fb8acbdd Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 22:22:35 -0500 Subject: [PATCH 23/30] test: add support for gradients property, clip factor to agc --- gradient_accumulator/accumulators.py | 7 +++-- tests/test_adaptive_gradient_clipping.py | 2 +- tests/test_optimizer.py | 33 ++++++++++++++++++++++-- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 4c5262f..f6b8ae1 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -276,7 +276,7 @@ def __init__( aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) if not hasattr(self, "_weights"): - self._weights = [] + self._weights = [] # noqa if not hasattr(self, "_gradients"): self._gradients = [] self._weights.append(self._step) @@ -341,7 +341,7 @@ def gradients(self) -> list: """Returns the current accumulated gradients on the replica.""" tf.debugging.assert_greater( tf.size(self._gradients), - self._zero, + tf.cast(self._zero, tf.int32), message="Gradients have not been computed yet. " "If you're using GradientAccumulateOptimizer with " "a custom training loop, please make sure to call " @@ -349,8 +349,7 @@ def gradients(self) -> list: "optimizer.gradients.", ) - empty_grad_tensor = tf.zeros([], dtype=self._gradient.dtype) - return get_gradients(self._gradients, empty_grad_tensor) + return get_gradients(self._gradients) def apply_gradients( self, grads_and_vars: dict, name: Optional[str] = None, **kwargs diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index 42a290d..e01a518 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -120,7 +120,7 @@ def test_train_mnist_optimizer(generate_experiment_prerequisites): opt = get_opt(opt_name="SGD", tf_version=tf_version) # need to scale optimizer for mixed precision - opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True) + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clip_factor=0.01) # compile model model.compile( diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 752d2ca..26cf214 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -24,7 +24,18 @@ def test_lr_getter(optimizer): def test_lr_setter(optimizer): optimizer.lr = 0.02 - assert optimizer.lr == 0.02 + + assert optimizer.lr == 0.02, "The lr getter did not return the updated learning rate." + + assert optimizer.base_optimizer.learning_rate.numpy() == 0.02, "The base_optimizer's learning rate was not updated correctly." + +def test_lr_setter_and_getter(optimizer): + new_learning_rate = 0.02 + optimizer.lr = new_learning_rate + + assert optimizer.base_optimizer.learning_rate.numpy() == new_learning_rate, "The base_optimizer's learning rate was not updated correctly." + + assert optimizer.lr == new_learning_rate, "The GradientAccumulateOptimizer's learning rate was not updated correctly." def test__learning_rate(optimizer): assert optimizer._learning_rate == 0.01 @@ -142,6 +153,23 @@ def test_resource_apply_sparse(optimizer_with_sparse_grads): tf.debugging.assert_near(accumulated_grads, expected_accumulated_grads, atol=1e-5) +def test_gradients_property(optimizer): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + + def loss_fn(): + return var[0]**2 + var[1]**2 + + with tf.GradientTape() as tape: + loss = loss_fn() + grads = tape.gradient(loss, [var]) + + optimizer.apply_gradients(zip(grads, [var])) + + accumulated_gradients = optimizer.gradients + + assert accumulated_gradients is not None, "Expected accumulated gradients to exist." + + if __name__ == "__main__": test__learning_rate(optimizer()) test_learning_rate_getter(optimizer()) @@ -155,4 +183,5 @@ def test_resource_apply_sparse(optimizer_with_sparse_grads): test_reset(optimizer()) test_parse_grad(optimizer()) test_reset_accum_gradient_condition(optimizer_with_grads()) - test_resource_apply_sparse(optimizer_with_sparse_grads()) \ No newline at end of file + test_resource_apply_sparse(optimizer_with_sparse_grads()) + test_gradients_property(optimizer()) \ No newline at end of file From 682b34689707d0229ea4b73ddc1ed395c3aefa4d Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 23:01:09 -0500 Subject: [PATCH 24/30] test: fix lr test --- tests/test_optimizer.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 26cf214..b41156a 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -27,15 +27,8 @@ def test_lr_setter(optimizer): assert optimizer.lr == 0.02, "The lr getter did not return the updated learning rate." - assert optimizer.base_optimizer.learning_rate.numpy() == 0.02, "The base_optimizer's learning rate was not updated correctly." - -def test_lr_setter_and_getter(optimizer): - new_learning_rate = 0.02 - optimizer.lr = new_learning_rate - - assert optimizer.base_optimizer.learning_rate.numpy() == new_learning_rate, "The base_optimizer's learning rate was not updated correctly." - - assert optimizer.lr == new_learning_rate, "The GradientAccumulateOptimizer's learning rate was not updated correctly." + assert optimizer.base_optimizer.learning_rate == 0.02 + assert optimizer._learning_rate == 0.02 def test__learning_rate(optimizer): assert optimizer._learning_rate == 0.01 @@ -52,7 +45,7 @@ def test_step_setter(optimizer): def test_iterations_setter(optimizer): optimizer.iterations = 1 assert optimizer.iterations == 1 - + def test_optimizer_prop(optimizer): assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__ @@ -128,11 +121,11 @@ def optimizer_adam(): def optimizer_with_sparse_grads(optimizer_adam): opt = optimizer_adam var = tf.Variable(tf.zeros([10, 10]), dtype=tf.float32) - + opt.add_slot(var, "ga", initializer=tf.zeros_like(var)) opt.add_slot(var, "m", initializer=tf.zeros_like(var)) opt.add_slot(var, "v", initializer=tf.zeros_like(var)) - + return opt, var def test_resource_apply_sparse(optimizer_with_sparse_grads): @@ -152,7 +145,6 @@ def test_resource_apply_sparse(optimizer_with_sparse_grads): expected_accumulated_grads = tf.scatter_nd(tf.expand_dims(indices, 1), updates * 2, var.shape) / optimizer.accum_steps tf.debugging.assert_near(accumulated_grads, expected_accumulated_grads, atol=1e-5) - def test_gradients_property(optimizer): var = tf.Variable([1.0, 2.0], dtype=tf.float32) From e97a43cd139cd4cc334b87adb3d18f3abd162c54 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 23:17:05 -0500 Subject: [PATCH 25/30] chore: lint --- gradient_accumulator/accumulators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index f6b8ae1..fb94f93 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -276,7 +276,7 @@ def __init__( aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) if not hasattr(self, "_weights"): - self._weights = [] # noqa + self._weights = [] # noqa if not hasattr(self, "_gradients"): self._gradients = [] self._weights.append(self._step) From 03d90763170e8b4f1bedee368e6b5cc9d6d03abd Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 23:30:24 -0500 Subject: [PATCH 26/30] fix: clip_factor -> clipvalue --- gradient_accumulator/accumulators.py | 14 +++++++------- tests/test_adaptive_gradient_clipping.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index fb94f93..94dbd1e 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -233,7 +233,7 @@ def __init__( "GradientAccumulateOptimizer". **kwargs : dict Additional keyword arguments. Allowed keys are: - - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. + - `clipvalue`: Sets upper limit for gradient clipping. Defaults to 0.01. - `lr`: Learning rate, included for backward compatibility. Use `learning_rate` instead. @@ -285,14 +285,14 @@ def __init__( self.use_agc = use_agc self._use_agc = tf.constant(use_agc) if use_agc: - if "clip_factor" in kwargs: - self.clip_factor = tf.constant( - kwargs.pop("clip_factor"), dtype=tf.float32 + if "clipvalue" in kwargs: + self.clipvalue = tf.constant( + kwargs.pop("clipvalue"), dtype=tf.float32 ) else: - self.clip_factor = tf.constant(0.01, dtype=tf.float32) + self.clipvalue = tf.constant(0.01, dtype=tf.float32) else: - self.clip_factor = tf.constant(0.0, dtype=tf.float32) + self.clipvalue = tf.constant(0.0, dtype=tf.float32) def get_slot(self, *args, **kwargs): """Returns a slot created by the optimizer.""" @@ -394,7 +394,7 @@ def apply_gradients( def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: """Applies adaptive gradient clipping to the gradient.""" return agc.adaptive_clip_grad( - [var], [grad], clip_factor=self.clip_factor + [var], [grad], clipvalue=self.clipvalue )[0] @tf.function diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index e01a518..9778da5 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -120,7 +120,7 @@ def test_train_mnist_optimizer(generate_experiment_prerequisites): opt = get_opt(opt_name="SGD", tf_version=tf_version) # need to scale optimizer for mixed precision - opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clip_factor=0.01) + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clipvalue=0.01) # compile model model.compile( From da1c3462ed8d31c5a615147380c67ee42d7cf217 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 23:31:31 -0500 Subject: [PATCH 27/30] chore: lint --- gradient_accumulator/accumulators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 94dbd1e..b1e2533 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -393,9 +393,9 @@ def apply_gradients( @tf.function def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: """Applies adaptive gradient clipping to the gradient.""" - return agc.adaptive_clip_grad( - [var], [grad], clipvalue=self.clipvalue - )[0] + return agc.adaptive_clip_grad([var], [grad], clipvalue=self.clipvalue)[ + 0 + ] @tf.function def _parse_grad( From 7188e30eca2c558c21cb00aa8d58ca402cfaa9fb Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sat, 10 Feb 2024 23:35:20 -0500 Subject: [PATCH 28/30] chore: lint --- gradient_accumulator/accumulators.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index b1e2533..7090088 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -393,9 +393,9 @@ def apply_gradients( @tf.function def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: """Applies adaptive gradient clipping to the gradient.""" - return agc.adaptive_clip_grad([var], [grad], clipvalue=self.clipvalue)[ - 0 - ] + return agc.adaptive_clip_grad( + [var], [grad], clip_factor=self.clipvalue + )[0] @tf.function def _parse_grad( From 8d87ab9ad3140f22f9810a5480ac05854d253c31 Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 11 Feb 2024 00:20:09 -0500 Subject: [PATCH 29/30] chore: lint --- gradient_accumulator/accumulators.py | 14 +++++--------- tests/test_adaptive_gradient_clipping.py | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 7090088..4fc784d 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -233,7 +233,7 @@ def __init__( "GradientAccumulateOptimizer". **kwargs : dict Additional keyword arguments. Allowed keys are: - - `clipvalue`: Sets upper limit for gradient clipping. Defaults to 0.01. + - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. - `lr`: Learning rate, included for backward compatibility. Use `learning_rate` instead. @@ -248,6 +248,7 @@ def __init__( .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa """ + clip_factor = kwargs.pop("clip_factor", 0.01) super().__init__(name, **kwargs) optimizer = tf.keras.optimizers.get(optimizer) self._optimizer = ( @@ -285,14 +286,9 @@ def __init__( self.use_agc = use_agc self._use_agc = tf.constant(use_agc) if use_agc: - if "clipvalue" in kwargs: - self.clipvalue = tf.constant( - kwargs.pop("clipvalue"), dtype=tf.float32 - ) - else: - self.clipvalue = tf.constant(0.01, dtype=tf.float32) + self.clip_factor = tf.constant(clip_factor, dtype=tf.float32) else: - self.clipvalue = tf.constant(0.0, dtype=tf.float32) + self.clip_factor = tf.constant(0.0, dtype=tf.float32) def get_slot(self, *args, **kwargs): """Returns a slot created by the optimizer.""" @@ -394,7 +390,7 @@ def apply_gradients( def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: """Applies adaptive gradient clipping to the gradient.""" return agc.adaptive_clip_grad( - [var], [grad], clip_factor=self.clipvalue + [var], [grad], clip_factor=self.clip_factor )[0] @tf.function diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index 9778da5..e01a518 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -120,7 +120,7 @@ def test_train_mnist_optimizer(generate_experiment_prerequisites): opt = get_opt(opt_name="SGD", tf_version=tf_version) # need to scale optimizer for mixed precision - opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clipvalue=0.01) + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clip_factor=0.01) # compile model model.compile( From 58e42e826ee9a57d9fd25dfa8bbf65f7edff054a Mon Sep 17 00:00:00 2001 From: Derek Pisner Date: Sun, 11 Feb 2024 01:01:18 -0500 Subject: [PATCH 30/30] test(cov): lr setter --- gradient_accumulator/accumulators.py | 2 +- tests/test_optimizer.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 4fc784d..bf2e9fc 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -277,7 +277,7 @@ def __init__( aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, ) if not hasattr(self, "_weights"): - self._weights = [] # noqa + self._weights = [] # pragma: no cover if not hasattr(self, "_gradients"): self._gradients = [] self._weights.append(self._step) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index b41156a..9622168 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -24,11 +24,17 @@ def test_lr_getter(optimizer): def test_lr_setter(optimizer): optimizer.lr = 0.02 + updated_lr = optimizer.lr.numpy() if hasattr(optimizer.lr, 'numpy') else optimizer.lr - assert optimizer.lr == 0.02, "The lr getter did not return the updated learning rate." + assert updated_lr == pytest.approx(0.02), "The lr getter did not return the updated learning rate." - assert optimizer.base_optimizer.learning_rate == 0.02 - assert optimizer._learning_rate == 0.02 + base_lr = optimizer.base_optimizer.learning_rate + base_lr = base_lr.numpy() if hasattr(base_lr, 'numpy') else base_lr + assert base_lr == pytest.approx(0.02), "The base_optimizer's learning rate was not updated." + + internal_lr = optimizer._learning_rate + internal_lr = internal_lr.numpy() if hasattr(internal_lr, 'numpy') else internal_lr + assert internal_lr == pytest.approx(0.02), "The internal _learning_rate attribute was not updated." def test__learning_rate(optimizer): assert optimizer._learning_rate == 0.01 @@ -45,7 +51,7 @@ def test_step_setter(optimizer): def test_iterations_setter(optimizer): optimizer.iterations = 1 assert optimizer.iterations == 1 - + def test_optimizer_prop(optimizer): assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__