@@ -737,11 +737,19 @@ def register_custom_support(
737737 grad_ckpt_context : Callable ,
738738 is_param_grad_enabled : Callable ,
739739 ) -> None :
740- """Register custom support for `AutoQuantize` score estimation.
740+ """(Optional) Register custom support for `AutoQuantize` score estimation.
741+
742+ This custom support is used to enable memory/compute efficient backward gradient propagation. This involves:
743+ - `grad_ckpt_context`: backward pass with gradient checkpointing enabled
744+ - `is_param_grad_enabled`: AutoQuantize only needs activation gradients to be computed (not weight
745+ gradients). `is_param_grad_enabled` is used to select which parameters should have gradients enabled,
746+ limiting gradient computation to only what's needed for activation gradients. For LLMs, enabling the
747+ gradient of just the embedding layer weights is sufficient, since that triggers computation of all
748+ downstream activation gradients.
741749
742750 If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
743751 used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
744- will be used to enable gradient for the parameter .
752+ will be used to select which parameters have gradients enabled to minimize gradient computation.
745753 """
746754 cls .custom_support .append ((is_supported_checker , grad_ckpt_context , is_param_grad_enabled ))
747755
@@ -793,10 +801,7 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
793801 output_diff -= output
794802 module .output_diff_dict [hparam ][recipe ] = output_diff .detach ()
795803
796- # Disable the configurable hparam so that they do not affect the any
797- # other hparam's score estimation
798- for hparam in module ._hparams_for_scoring :
799- if hparam .is_configurable :
804+ # Disable the configurable hparam now that its output diff has been computed.
800805 hparam .active = QuantRecipe (quant_cfg = None )
801806
802807 return output
0 commit comments