diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 46da364..0320f8d 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -71,7 +71,8 @@ jobs: --cov=gradient_accumulator tests/test_mp_batch_norm.py \ --cov=gradient_accumulator tests/test_bn_convnd.py \ --cov=gradient_accumulator tests/test_bn_pretrained_swap.py \ - --cov=gradient_accumulator tests/test_model_distribute.py + --cov=gradient_accumulator tests/test_model_distribute.py \ + --cov=gradient_accumulator tests/test_optimizer.py - name: Lint with flake8 run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics diff --git a/gradient_accumulator/accumulators.py b/gradient_accumulator/accumulators.py index 7baad13..bf2e9fc 100644 --- a/gradient_accumulator/accumulators.py +++ b/gradient_accumulator/accumulators.py @@ -1,6 +1,9 @@ +from typing import Optional + import tensorflow as tf from . import agc +from .utils import get_gradients # dynamically handle which Optimizer class to use dep on tf version opt = tf.keras.optimizers.Optimizer @@ -9,7 +12,7 @@ # https://stackoverflow.com/a/66524901 -# https://keras.io/guides/customizing_what_happens_in_fit/ +# https://keras.io/guides/customizing_what_happens_in_fit/ # noqa @tf.keras.utils.register_keras_serializable("gradient-accumulator") class GradientAccumulateModel(tf.keras.Model): """Model wrapper for gradient accumulation.""" @@ -23,7 +26,7 @@ def __init__( eps: float = 1e-3, experimental_distributed_support: bool = False, *args, - **kwargs + **kwargs, ): """Adds gradient accumulation support to existing Keras Model. @@ -201,114 +204,166 @@ class GradientAccumulateOptimizer(opt): def __init__( self, - optimizer="SGD", - accum_steps=1, + optimizer: str = "SGD", + accum_steps: int = 1, reduction: str = "MEAN", + use_agc: bool = False, + mixed_precision: bool = False, name: str = "GradientAccumulateOptimizer", - **kwargs + dtype: tf.dtypes.DType = tf.float32, + **kwargs, ): - """Construct a new GradientAccumulateOptimizer optimizer. - - Adding support for sparse tensors was tricky, but this resource was - helpful. Note that you need to implement both _resource_apply_sparse() - and _resource_apply_sparse_duplicate_indices() for it to work as - intended. - - See here for more information regarding implementation: - * https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa + """ + Construct a new GradientAccumulateOptimizer optimizer. + + Parameters + ---------- + optimizer : str or tf.keras.optimizers.Optimizer + Optimizer that will be used to compute and apply gradients. + accum_steps : int, optional + Update gradient in every accumulation steps, must be > 0. + reduction : str, optional + Gradient reduction method to use. Can be 'MEAN' or 'SUM'. + use_agc : bool, optional + Whether to use adaptive gradient clipping. Defaults to False. + mixed_precision : bool, optional + Whether to use mixed precision. Defaults to False. + name : str, optional + Name for the operations created when applying gradients. Defaults to + "GradientAccumulateOptimizer". + **kwargs : dict + Additional keyword arguments. Allowed keys are: + - `clip_factor`: Sets upper limit for gradient clipping. Defaults to 0.01. + - `lr`: Learning rate, included for backward compatibility. Use + `learning_rate` instead. + + Notes + ----- + Adding support for sparse tensors was tricky. For correct implementation, both + `_resource_apply_sparse()` + and `_resource_apply_sparse_duplicate_indices()` methods need to be implemented. + + References + ---------- + .. [1] https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93 # noqa - Args: - optimizer: str or `tf.keras.optimizers.Optimizer` that will be - used to compute and apply gradients. - accum_steps: int > 0. Update gradient in every accumulation steps. - reduction: str. Which gradient reduction method to use. Defaults - to 'SUM'. - name: Optional name for the operations created when applying - gradients. Defaults to "GradientAccumulateOptimizer". - **kwargs: keyword arguments. Allowed to be {`clipnorm`, - `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by - norm; `clipvalue` is clip gradients by value, `decay` is - included for backward compatibility to allow time inverse - decay of learning rate. `lr` is included for backward - compatibility, recommended to use `learning_rate` instead. """ - self._optimizer = tf.keras.optimizers.get(optimizer) - self._accum_steps = accum_steps - self._reduction = reduction - self._step = None + clip_factor = kwargs.pop("clip_factor", 0.01) super().__init__(name, **kwargs) + optimizer = tf.keras.optimizers.get(optimizer) + self._optimizer = ( + tf.keras.mixed_precision.LossScaleOptimizer(optimizer) + if mixed_precision + and not isinstance( + optimizer, tf.keras.mixed_precision.LossScaleOptimizer + ) + else optimizer + ) + self.base_optimizer = ( + self._optimizer.inner_optimizer + if mixed_precision + else self._optimizer + ) + self.mixed_precision = mixed_precision + self._mixed_precision = tf.constant(mixed_precision, dtype=tf.bool) + self.accum_steps = accum_steps + self._accum_steps = tf.constant(accum_steps, dtype=tf.int64) + self.reduction = reduction + self._reduction = tf.constant(reduction, dtype=tf.string) + self._step = tf.Variable( + initial_value=1, + trainable=False, + dtype=tf.int64, + aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, + ) + if not hasattr(self, "_weights"): + self._weights = [] # pragma: no cover + if not hasattr(self, "_gradients"): + self._gradients = [] + self._weights.append(self._step) + self._zero = tf.constant(0, dtype=tf.int64) + self.dtype = dtype + self.use_agc = use_agc + self._use_agc = tf.constant(use_agc) + if use_agc: + self.clip_factor = tf.constant(clip_factor, dtype=tf.float32) + else: + self.clip_factor = tf.constant(0.0, dtype=tf.float32) - def _create_slots(self, var_list): - """Creates slots for optimizer gradients. + def get_slot(self, *args, **kwargs): + """Returns a slot created by the optimizer.""" + return self._optimizer.get_slot(*args, **kwargs) + + def add_slot(self, var, slot_name, initializer): + """Adds a new slot to the optimizer.""" + slot = self._optimizer.add_slot(var, slot_name, initializer=initializer) + self._gradients.append(slot) + return slot + + def _create_slots(self, var_list: list): + """Creates slots for the optimizer.""" + # create slots using the base optimizer + self.base_optimizer._create_slots(var_list=var_list) + + base_optimizer_slots = self.base_optimizer.get_slot_names() - Args: - List of trainable variables. - """ - self._optimizer._create_slots(var_list=var_list) for var in var_list: - self.add_slot(var, "ga") + for slot_name in base_optimizer_slots: + self.add_slot( + var, + slot_name, + initializer=tf.zeros_like(var), + ) + + # create slots for accumulated gradients + for var in var_list: + self.add_slot(var, "ga", initializer=tf.zeros_like(var)) self._gradients = [self.get_slot(var, "ga") for var in var_list] @property - def step(self): - """The number of training steps this Optimizer has run. - Initializes step variable if None. - - Returns: - Current number of optimizer steps. - """ - if self._step is None: - with self._distribution_strategy_scope(): - self._step = self.add_weight( - "iter", - shape=[], - initializer="ones", - dtype=tf.int64, - trainable=False, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - ) - self._weights.append(self._step) + def step(self) -> tf.Variable: + """Returns the number of training steps this Optimizer has run.""" return self._step @step.setter - def step(self, variable): # pragma: no cover + def step(self, variable: tf.Variable): """Sets the step value.""" - if self._step is not None: - raise RuntimeError( - "Cannot set `step` to a new Variable after " - "the Optimizer weights have been created" - ) self._step = variable self._weights.append(self._step) @property - def gradients(self): # pragma: no cover - """The accumulated gradients on the current replica. - - Returns: - Current gradients in optimizer. - """ - if not self._gradients: - raise ValueError( - "The accumulator should be called first to initialize the" - "gradients" - ) - return list( - gradient.read_value() if gradient is not None else gradient - for gradient in self._gradients + def gradients(self) -> list: + """Returns the current accumulated gradients on the replica.""" + tf.debugging.assert_greater( + tf.size(self._gradients), + tf.cast(self._zero, tf.int32), + message="Gradients have not been computed yet. " + "If you're using GradientAccumulateOptimizer with " + "a custom training loop, please make sure to call " + "optimizer.apply_gradients() before accessing " + "optimizer.gradients.", ) - def apply_gradients(self, grads_and_vars, name=None, **kwargs): - """Updates weights using gradients. + return get_gradients(self._gradients) + + def apply_gradients( + self, grads_and_vars: dict, name: Optional[str] = None, **kwargs + ) -> tf.Operation: + """Applies gradients to variables and updates the optimizer's state. + + Parameters + ---------- + grads_and_vars : dict + A dictionary of {gradient: variable} pairs. + name : Optional[str], optional + The name for the operation. Defaults to None. + + Returns + ------- + tf.Operation + The operation after applying the gradients. - Args: - grads_and_vars: dict containing variables and corresponding - gradients. - name: name to set when applying gradients. - **kwargs: keyword arguments. - Return: - Updated weights. """ train_op = super().apply_gradients(grads_and_vars, name, **kwargs) with tf.control_dependencies([train_op]): @@ -316,7 +371,13 @@ def apply_gradients(self, grads_and_vars, name=None, **kwargs): [ self._optimizer.iterations.assign_add( tf.cast( - tf.where(self.step % self._accum_steps == 0, 1, 0), + tf.equal( + tf.math.mod( + self.step, + self._accum_steps, + ), + self._zero, + ), tf.int64, ), read_value=False, @@ -325,47 +386,99 @@ def apply_gradients(self, grads_and_vars, name=None, **kwargs): ): return self.step.assign_add(1, read_value=False) - def _resource_apply_dense( - self, grad, var, apply_state=None - ): # pragma: no cover - """Performs gradient update on dense tensor. + @tf.function + def _apply_agc(self, grad: tf.Tensor, var: tf.Variable) -> tf.Tensor: + """Applies adaptive gradient clipping to the gradient.""" + return agc.adaptive_clip_grad( + [var], [grad], clip_factor=self.clip_factor + )[0] + + @tf.function + def _parse_grad( + self, accum_gradient: tf.Tensor, var: tf.Variable + ) -> tf.Tensor: + """Parses the accumulated gradient and returns the gradient to be + applied.""" + + apply_condition = tf.fill( + tf.shape(accum_gradient), + tf.equal(tf.math.mod(self.step, self._accum_steps), self._zero), + ) - Args: - grad: current gradient. - var: current variable. - apply_state: whether to apply X. - Returns: + def apply_agc(): + return self._apply_agc(accum_gradient, var) + + def return_grad(): + return accum_gradient + + return tf.where( + apply_condition, + tf.cond(self._use_agc, apply_agc, return_grad), + tf.zeros_like(var, dtype=accum_gradient.dtype), + ) + + @tf.function + def reset_accum_gradient(self, accum_gradient: tf.Tensor, grad: tf.Tensor): + """Resets the accumulated gradient to zero after applying.""" + return tf.where( + tf.math.equal(grad, accum_gradient), + tf.zeros_like(accum_gradient), + accum_gradient, + ) + + def _resource_apply_dense( + self, + grad: tf.Tensor, + var: tf.Variable, + apply_state: Optional[str] = None, + ) -> tf.Operation: + """ + Performs gradient update on sparse tensor. + + Parameters + ---------- + grad : tensor + Current gradient. + var : tensor + Current variable. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor apply_op. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - accum_gradient.assign_add( - grad / self._accum_steps, - use_locking=self._use_locking, - read_value=False, - ) - def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), - ) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) - if "apply_state" in self._optimizer._dense_apply_args: - train_op = self._optimizer._resource_apply_dense( - grad, var, apply_state=apply_state - ) - else: - train_op = self.optimizer._resource_apply_dense(grad, var) + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + accum_gradient.assign_add( + scaled_grad, use_locking=self._use_locking, read_value=False + ) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, + def _apply(accum_gradient, var, apply_state): + grad = self._parse_grad(accum_gradient, var) + + train_op = self.base_optimizer._resource_apply_dense( + grad, + var, + apply_state=apply_state if apply_state else None, ) + reset_op = accum_gradient.assign( - reset_val, + self.reset_accum_gradient(accum_gradient, grad), use_locking=self._use_locking, read_value=False, ) @@ -375,50 +488,65 @@ def _apply(accum_gradient, var, apply_state): return _apply(accum_gradient, var, apply_state) def _resource_apply_sparse( - self, grad, var, indices, apply_state=None - ): # pragma: no cover + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ) -> tf.Operation: """Performs gradient update on sparse tensor. - Args: - grad: current gradient. - var: current variable. - indices: relevant indices to be used for masking the sparse tensor - during update. - Returns: - apply_op. + Parameters + ---------- + grad : tensor + The current gradient. + var : tensor + The current variable. + indices : tensor + Relevant indices to be used for masking the sparse tensor during + update. + apply_state : str, optional + State of the optimizer. Defaults to None. + + Returns + ------- + tensor + The operation after applying the gradient update. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - grad /= tf.cast(self._accum_steps, dtype=grad.dtype) - self._resource_scatter_add(accum_gradient, indices, grad) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), - ) - if "apply_state" in self.optimizer._sparse_apply_args: - train_op = self.optimizer._resource_apply_sparse( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state, - ) - else: - train_op = self.optimizer._resource_apply_sparse( - accum_gradient.sparse_read(indices), var, indices - ) + grad = self._parse_grad(accum_gradient, var) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, + train_op = self.base_optimizer._resource_apply_sparse( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state if apply_state else None, ) + reset_op = accum_gradient.assign( - reset_val, + self.reset_accum_gradient(accum_gradient, grad), use_locking=self._use_locking, read_value=False, ) @@ -427,56 +555,63 @@ def _apply(accum_gradient, var, apply_state): return _apply(accum_gradient, var, apply_state) - # TODO: needs to be updated and tested def _resource_apply_sparse_duplicate_indices( - self, grad, var, indices, apply_state=None - ): # pragma: no cover - """Performs gradient update on sparse tensor. - - Args: - grad: current gradient. - var: current variable. - indices: relevant indices to be used for masking the sparse tensor - during update. - Returns: - apply_op. + self, + grad: tf.Tensor, + var: tf.Variable, + indices: tf.Tensor, + apply_state: Optional[str] = None, + ) -> tf.Operation: """ + Performs gradient update on sparse tensor with duplicate indices. + + Parameters + ---------- + grad : tf.Tensor + Current gradient. + var : tf.Variable + Current variable. + indices : tf.Tensor + Relevant indices to be used for masking the sparse tensor during + update. + apply_state : str, optional + State of the optimizer. Defaults to None. + """ accum_gradient = self.get_slot(var, "ga") - if accum_gradient is not None and grad is not None: - grad /= tf.cast(self._accum_steps, dtype=grad.dtype) - self._resource_scatter_add(accum_gradient, indices, grad) + # undo loss scaling and revert to higher precision + grad_to_use = ( + self._optimizer.get_unscaled_gradients([grad])[0] + if self.mixed_precision + else grad + ) + + # scale down the gradient and add it to the accumulated gradient + scaled_grad = tf.math.divide_no_nan( + grad_to_use, tf.cast(self._accum_steps, dtype=grad_to_use.dtype) + ) + + self._resource_scatter_add( + accum_gradient, + indices, + scaled_grad, + ) def _apply(accum_gradient, var, apply_state): - grad = tf.where( - self.step % self._accum_steps == 0, - accum_gradient, - tf.zeros_like(var), - ) - if "apply_state" in self.optimizer._sparse_apply_args: - train_op = ( - self.optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), - var, - indices, - apply_state=apply_state, - ) - ) - else: - train_op = ( - self.optimizer._resource_apply_sparse_duplicate_indices( - accum_gradient.sparse_read(indices), var, indices - ) - ) + grad = self._parse_grad(accum_gradient, var) - reset_val = tf.where( - grad == accum_gradient, - tf.zeros_like(accum_gradient), - accum_gradient, + train_op = ( + self.base_optimizer._resource_apply_sparse_duplicate_indices( + accum_gradient.sparse_read(indices), + var, + indices, + apply_state=apply_state if apply_state else None, + ) ) + reset_op = accum_gradient.assign( - reset_val, + self.reset_accum_gradient(accum_gradient, grad), use_locking=self._use_locking, read_value=False, ) @@ -485,74 +620,84 @@ def _apply(accum_gradient, var, apply_state): return _apply(accum_gradient, var, apply_state) - def reset(self): # pragma: no cover - """Resets the accumulated gradients on the current replica.""" - assign_ops = [] - if not self._gradients: - return assign_ops - - for gradient in self._gradients: - if gradient is not None: - assign_ops.append( - gradient.assign( - tf.zeros_like(gradient), - use_locking=self._use_locking, - read_value=False, - ) - ) + def _reset_single_gradient(self, gradient: tf.Tensor): + """Resets the accumulated gradient on the current replica.""" + return gradient.assign( + tf.zeros_like(gradient), + use_locking=self._use_locking, + read_value=False, + ) - return tf.group(assign_ops) + def reset(self): + """Resets the accumulated gradients on the current replica.""" + reset_ops = [ + self._reset_single_gradient(gradient) + for gradient in self._gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) + ] + return tf.group(*reset_ops) @property - def optimizer(self): - """The optimizer that this AccumOptimizer is wrapping.""" + def optimizer(self) -> tf.keras.optimizers.Optimizer: + """The optimizer that this AccumOptimizer is wrapping. In the case of mixed + precision, this is the LossScaleOptimizer.""" return self._optimizer @property - def iterations(self): - """Returns current iteration value of optimizer. - - Returns: - iterations of optimizer.""" + def iterations(self) -> tf.Variable: + """Returns current iteration value of optimizer.""" return self._optimizer.iterations @iterations.setter - def iterations(self, variable): + def iterations(self, variable: tf.Variable): """Sets the iterations value of optimizer.""" self._optimizer.iterations = variable @property - def learning_rate(self): # pragma: no cover - """Returns the learning rate of the optimizer. + def lr(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.base_optimizer.learning_rate - Returns: - learning rate of optimizer. - """ - return self._optimizer._get_hyper("learning_rate") + @lr.setter + def lr(self, lr): + """Sets the learning rate of the optimizer.""" + self.base_optimizer.learning_rate = lr + self._learning_rate = lr - @learning_rate.setter - def learning_rate(self, learning_rate): # pragma: no cover - """Sets the learning rate of the optimizer. + @property + def learning_rate(self): + return self.base_optimizer.learning_rate - Args: - learning_rate: which learning rate to set in the optimizer. - """ - self._optimizer._set_hyper("learning_rate", learning_rate) + @learning_rate.setter + def learning_rate(self, lr): + self.base_optimizer.learning_rate = lr - def get_config(self): - """Returns the configuration as dict.""" - config = { + @property + def _learning_rate(self) -> float: + """Returns the learning rate of the optimizer.""" + return self.lr + + def get_config(self) -> dict: + """Returns the configuration as a dictionary.""" + config = super().get_config() + custom_config = { "optimizer": tf.keras.optimizers.serialize(self._optimizer), - "accum_steps": self._accum_steps, - "reduction": self._reduction, + "accum_steps": self.accum_steps, + "reduction": self.reduction, + "use_agc": self.use_agc, + "mixed_precision": self.mixed_precision, + "dtype": self.dtype.name, } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) + config.update(custom_config) + return config @classmethod - def from_config(cls, config, custom_objects=None): - """Gets config of original optimizer and deserializes it.""" + def from_config( + cls, config: dict, custom_objects: Optional[str] = None + ) -> object: + """Creates an instance of the optimizer from its config.""" + optimizer_config = config.pop("optimizer") optimizer = tf.keras.optimizers.deserialize( - config.pop("optimizer"), custom_objects=custom_objects + optimizer_config, custom_objects=custom_objects ) - return cls(optimizer, **config) + return cls(optimizer=optimizer, **config) diff --git a/gradient_accumulator/utils.py b/gradient_accumulator/utils.py index 14785eb..4e0bb5b 100644 --- a/gradient_accumulator/utils.py +++ b/gradient_accumulator/utils.py @@ -73,3 +73,11 @@ def replace_batchnorm_layers(model, accum_steps, position="replace"): model_outputs.append(x) return tf.keras.Model(inputs=model.inputs, outputs=x) + + +def get_gradients(gradients: list): + return [ + gradient.read_value() + for gradient in gradients + if tf.reduce_all(tf.not_equal(tf.size(gradient), 0)) + ] diff --git a/setup.py b/setup.py index bb68799..ba5b29c 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setuptools.setup( name="gradient-accumulator", version="0.5.2", - author="André Pedersen and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo", + author="André Pedersen and Derek Alexander and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo", author_email="andrped94@gmail.com", description="Package for gradient accumulation in TensorFlow", long_description=long_description, diff --git a/shell/format.sh b/shell/format.sh old mode 100644 new mode 100755 diff --git a/shell/lint.sh b/shell/lint.sh old mode 100644 new mode 100755 diff --git a/tests/test_adaptive_gradient_clipping.py b/tests/test_adaptive_gradient_clipping.py index d68a8ea..e01a518 100644 --- a/tests/test_adaptive_gradient_clipping.py +++ b/tests/test_adaptive_gradient_clipping.py @@ -1,15 +1,16 @@ -import os - +import pytest import tensorflow as tf import tensorflow_datasets as tfds -from tensorflow.keras import mixed_precision from tensorflow.keras.models import load_model from gradient_accumulator import GradientAccumulateModel +from gradient_accumulator import GradientAccumulateOptimizer from gradient_accumulator import unitwise_norm -from .utils import normalize_img +from .utils import normalize_img, get_opt + +tf_version = int(tf.version.VERSION.split(".")[1]) def test_unitwise_norm(): for i in range(7): @@ -27,7 +28,8 @@ def test_unitwise_norm(): raise e -def test_train_mnist(): +@pytest.fixture +def generate_experiment_prerequisites(): # load dataset (ds_train, ds_test), ds_info = tfds.load( "mnist", @@ -63,7 +65,14 @@ def test_train_mnist(): ), # output not numerically stable with float16 ] ) + return model, ds_train, ds_test + + +def test_train_mnist_model(generate_experiment_prerequisites): + + model, ds_train, ds_test = generate_experiment_prerequisites + # Test AGC with model # wrap model to use gradient accumulation model = GradientAccumulateModel( accum_steps=4, @@ -74,7 +83,44 @@ def test_train_mnist(): ) # need to scale optimizer for mixed precision - opt = tf.keras.optimizers.SGD(1e-2) + opt = get_opt(opt_name="SGD", tf_version=tf_version) + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit( + ds_train, + epochs=1, + validation_data=ds_test, + ) + + model.save("./trained_model") + + # load trained model and test + del model + trained_model = load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + +def test_train_mnist_optimizer(generate_experiment_prerequisites): + + model, ds_train, ds_test = generate_experiment_prerequisites + + + # wrap model to use gradient accumulation + model = tf.keras.Model(inputs=model.input, outputs=model.output) + + opt = get_opt(opt_name="SGD", tf_version=tf_version) + + # need to scale optimizer for mixed precision + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=False, use_agc=True, clip_factor=0.01) # compile model model.compile( @@ -102,4 +148,5 @@ def test_train_mnist(): # for running locally, outside pytest if __name__ == "__main__": - test_train_mnist() + test_train_mnist_model() + test_train_mnist_optimizer() diff --git a/tests/test_mixed_precision.py b/tests/test_mixed_precision.py index d24d50a..68056ef 100644 --- a/tests/test_mixed_precision.py +++ b/tests/test_mixed_precision.py @@ -1,23 +1,23 @@ +import pytest +import tensorflow as tf import multiprocessing as mp +from .utils import reset, get_opt -def run_experiment(): +tf_version = int(tf.version.VERSION.split(".")[1]) + +@pytest.fixture +def generate_experiment_prerequisites(): import os import tensorflow as tf import tensorflow_datasets as tfds - from tensorflow.keras import mixed_precision - - from gradient_accumulator import GradientAccumulateModel from .utils import normalize_img # disable GPU os.environ["CUDA_VISIBLE_DEVICES"] = "-1" - # set mixed global precision policy - mixed_precision.set_global_policy("mixed_float16") - # load dataset (ds_train, ds_test), ds_info = tfds.load( "mnist", @@ -52,6 +52,19 @@ def run_experiment(): ), # output not numerically stable with float16 ] ) + return model, ds_train, ds_test + + +def run_experiment_model(generate_experiment_prerequisites): + import tensorflow as tf + from tensorflow.keras import mixed_precision + + from gradient_accumulator import GradientAccumulateModel + + # set mixed global precision policy + mixed_precision.set_global_policy("mixed_float16") + + model, ds_train, ds_test = generate_experiment_prerequisites # wrap model to use gradient accumulation model = GradientAccumulateModel( @@ -90,7 +103,56 @@ def run_experiment(): print(result) +def run_experiment_optimizer(generate_experiment_prerequisites): + import tensorflow as tf + from tensorflow.keras import mixed_precision + + from gradient_accumulator import GradientAccumulateOptimizer + + # set mixed global precision policy + mixed_precision.set_global_policy("mixed_float16") + + model, ds_train, ds_test = generate_experiment_prerequisites + + # wrap model to use gradient accumulation + model = tf.keras.Model(inputs=model.input, outputs=model.output) + + opt = get_opt(opt_name="adam", tf_version=tf_version) + + # need to scale optimizer for mixed precision + opt = GradientAccumulateOptimizer(opt, accum_steps=4, mixed_precision=True) + + # compile model + model.compile( + optimizer=opt, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[tf.keras.metrics.SparseCategoricalAccuracy()], + ) + + # train model + model.fit( + ds_train, + epochs=1, + validation_data=ds_test, + ) + + # save model on disk + model.save("./trained_model") + + # load trained model and test + del model + trained_model = tf.keras.models.load_model("./trained_model", compile=True) + + result = trained_model.evaluate(ds_test, verbose=1) + print(result) + + def test_mixed_precision(): + # set seed + reset() + + # Model with mixed precision + # launch experiment in separate process, as we are enabling mixed precision # which will impact other unit tests, unless we do this try: @@ -107,8 +169,34 @@ def test_mixed_precision(): except RuntimeError: pass - p = mp.Process(target=run_experiment) + p = mp.Process(target=run_experiment_model) try: p.start() finally: p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) + + reset() + + # Optimizer with mixed precision + + # launch experiment in separate process, as we are enabling mixed precision + # which will impact other unit tests, unless we do this + try: + from pytest_cov.embed import cleanup_on_sigterm + except ImportError: + pass + else: + cleanup_on_sigterm() + + try: + mp.set_start_method( + "spawn", force=True + ) # set start method to 'spawn' BEFORE instantiating the queue and the event + except RuntimeError: + pass + + p = mp.Process(target=run_experiment_optimizer) + try: + p.start() + finally: + p.join() # necessary so that the Process exists before the test suite exits (thus coverage is collected) \ No newline at end of file diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..9622168 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,185 @@ +import pytest +import tensorflow as tf +from gradient_accumulator import GradientAccumulateOptimizer +from .utils import get_opt + +tf_version = int(tf.version.VERSION.split(".")[1]) + +tf.config.run_functions_eagerly(True) + +@pytest.fixture +def optimizer(): + opt = get_opt(opt_name="SGD", tf_version=tf_version) + return GradientAccumulateOptimizer(optimizer=opt, accum_steps=2) + +def test_learning_rate_getter(optimizer): + assert optimizer.learning_rate == 0.01 + +def test_learning_rate_setter(optimizer): + optimizer.learning_rate = 0.02 + assert optimizer.learning_rate == 0.02 + +def test_lr_getter(optimizer): + assert optimizer.lr == 0.01 + +def test_lr_setter(optimizer): + optimizer.lr = 0.02 + updated_lr = optimizer.lr.numpy() if hasattr(optimizer.lr, 'numpy') else optimizer.lr + + assert updated_lr == pytest.approx(0.02), "The lr getter did not return the updated learning rate." + + base_lr = optimizer.base_optimizer.learning_rate + base_lr = base_lr.numpy() if hasattr(base_lr, 'numpy') else base_lr + assert base_lr == pytest.approx(0.02), "The base_optimizer's learning rate was not updated." + + internal_lr = optimizer._learning_rate + internal_lr = internal_lr.numpy() if hasattr(internal_lr, 'numpy') else internal_lr + assert internal_lr == pytest.approx(0.02), "The internal _learning_rate attribute was not updated." + +def test__learning_rate(optimizer): + assert optimizer._learning_rate == 0.01 + optimizer.learning_rate = 0.02 + assert optimizer._learning_rate == 0.02 + +def test_step_getter(optimizer): + assert optimizer.step == 1 + +def test_step_setter(optimizer): + optimizer.step = 1 + assert optimizer.step == 1 + +def test_iterations_setter(optimizer): + optimizer.iterations = 1 + assert optimizer.iterations == 1 + +def test_optimizer_prop(optimizer): + assert optimizer.optimizer.__class__ == get_opt(opt_name="SGD", tf_version=tf_version).__class__ + +def test_reset_single_gradient(optimizer): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + optimizer.add_slot(var, "ga", initializer=tf.constant([3.0, 4.0])) + gradient = optimizer.get_slot(var, "ga") + optimizer._reset_single_gradient(gradient) + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))) + +def test_reset(optimizer): + var1 = tf.Variable([1.0, 2.0], dtype=tf.float32) + var2 = tf.Variable([3.0, 4.0], dtype=tf.float32) + optimizer.add_slot(var1, "ga", initializer=tf.constant([5.0, 6.0])) + optimizer.add_slot(var2, "ga", initializer=tf.constant([7.0, 8.0])) + for var in [var1, var2]: + gradient = optimizer.get_slot(var, "ga") + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == False + + optimizer.reset() + for var in [var1, var2]: + gradient = optimizer.get_slot(var, "ga") + assert tf.reduce_all(tf.equal(gradient, tf.zeros_like(gradient))).numpy() == True + + +@pytest.mark.parametrize("accum_steps", [1, 2]) +@pytest.mark.parametrize("use_agc", [True, False]) +def test_parse_grad(optimizer, use_agc, accum_steps): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + if accum_steps == 1: + expected_grad = tf.zeros_like(var) # gradients should not be applied yet + else: + expected_grad = tf.constant([3.0, 4.0]) + optimizer.add_slot(var, "ga", initializer=expected_grad) + accum_gradient = optimizer.get_slot(var, "ga") + + optimizer.use_agc = use_agc + optimizer.step.assign(accum_steps) + + parsed_grad = optimizer._parse_grad(accum_gradient, var) + assert tf.reduce_all(tf.equal(parsed_grad, expected_grad)).numpy() == True + + +@pytest.fixture +def optimizer_with_grads(optimizer): + opt = optimizer + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + + opt.add_slot(var, "ga", initializer=tf.constant([3.0, 4.0])) + + return opt, var + +def test_reset_accum_gradient_condition(optimizer_with_grads): + optimizer, var = optimizer_with_grads + + accum_grad = optimizer.get_slot(var, "ga") + accum_grad.assign(tf.constant([3.0, 4.0], dtype=tf.float32)) + + current_grad = tf.constant([3.0, 4.0], dtype=tf.float32) + + result_grad = optimizer.reset_accum_gradient(accum_grad, current_grad) + + expected_grad = tf.zeros_like(accum_grad) + + tf.debugging.assert_equal(result_grad, expected_grad, message="Gradients should be reset to zeros") + +@pytest.fixture +def optimizer_adam(): + opt = get_opt(opt_name="adam", tf_version=tf_version) + return GradientAccumulateOptimizer(optimizer=opt, accum_steps=2) + +@pytest.fixture +def optimizer_with_sparse_grads(optimizer_adam): + opt = optimizer_adam + var = tf.Variable(tf.zeros([10, 10]), dtype=tf.float32) + + opt.add_slot(var, "ga", initializer=tf.zeros_like(var)) + opt.add_slot(var, "m", initializer=tf.zeros_like(var)) + opt.add_slot(var, "v", initializer=tf.zeros_like(var)) + + return opt, var + +def test_resource_apply_sparse(optimizer_with_sparse_grads): + optimizer, var = optimizer_with_sparse_grads + + indices = tf.constant([0, 1], dtype=tf.int64) + updates = tf.constant([[0.1] * 10, [0.2] * 10], dtype=tf.float32) + + optimizer._reset_single_gradient(optimizer.get_slot(var, "ga")) + + grad = tf.IndexedSlices(values=updates, indices=indices, dense_shape=var.shape) + + optimizer._resource_apply_sparse(grad.values, var, grad.indices) + optimizer._resource_apply_sparse(grad.values, var, grad.indices) + + accumulated_grads = optimizer.get_slot(var, "ga") + expected_accumulated_grads = tf.scatter_nd(tf.expand_dims(indices, 1), updates * 2, var.shape) / optimizer.accum_steps + tf.debugging.assert_near(accumulated_grads, expected_accumulated_grads, atol=1e-5) + +def test_gradients_property(optimizer): + var = tf.Variable([1.0, 2.0], dtype=tf.float32) + + def loss_fn(): + return var[0]**2 + var[1]**2 + + with tf.GradientTape() as tape: + loss = loss_fn() + grads = tape.gradient(loss, [var]) + + optimizer.apply_gradients(zip(grads, [var])) + + accumulated_gradients = optimizer.gradients + + assert accumulated_gradients is not None, "Expected accumulated gradients to exist." + + +if __name__ == "__main__": + test__learning_rate(optimizer()) + test_learning_rate_getter(optimizer()) + test_learning_rate_setter(optimizer()) + test_lr_getter(optimizer()) + test_lr_setter(optimizer()) + test_step_getter(optimizer()) + test_step_setter(optimizer()) + test_optimizer_prop(optimizer()) + test_reset_single_gradient(optimizer()) + test_reset(optimizer()) + test_parse_grad(optimizer()) + test_reset_accum_gradient_condition(optimizer_with_grads()) + test_resource_apply_sparse(optimizer_with_sparse_grads()) + test_gradients_property(optimizer()) \ No newline at end of file