
Commit d210561

minor updates

Signed-off-by: realAsma <[email protected]>

1 parent 0275c61

2 files changed: +11 additions, -8 deletions


modelopt/torch/quantization/algorithms.py

Lines changed: 11 additions & 6 deletions
@@ -737,11 +737,19 @@ def register_custom_support(
         grad_ckpt_context: Callable,
         is_param_grad_enabled: Callable,
     ) -> None:
-        """Register custom support for `AutoQuantize` score estimation.
+        """(Optional) Register custom support for `AutoQuantize` score estimation.
+
+        This custom support enables memory- and compute-efficient backward gradient propagation. It involves:
+            - `grad_ckpt_context`: runs the backward pass with gradient checkpointing enabled.
+            - `is_param_grad_enabled`: AutoQuantize only needs activation gradients to be computed (not weight
+              gradients), so `is_param_grad_enabled` selects which parameters have gradients enabled, limiting
+              gradient computation to what is needed for activation gradients. For LLMs, enabling just the
+              embedding layer's weight gradient is sufficient: it triggers computation of all downstream
+              activation gradients.

         If the `is_supported_checker(model)` returns True, the `grad_ckpt_context(model)` will be
         used to enable gradient checkpointing and `is_param_grad_enabled(pname, model)`
-        will be used to enable gradient for the parameter.
+        will be used to select which parameters have gradients enabled to minimize gradient computation.
         """
         cls.custom_support.append((is_supported_checker, grad_ckpt_context, is_param_grad_enabled))
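
For context, a minimal sketch of how a caller might register such custom support. The helper names below and the commented-out registration call are hypothetical; only the `register_custom_support` argument order is taken from the tuple appended in the diff above:

from contextlib import contextmanager

def _is_supported(model) -> bool:
    # Hypothetical checker: handle models that expose HF-style gradient
    # checkpointing toggles and an embedding layer.
    return hasattr(model, "gradient_checkpointing_enable") and hasattr(model, "get_input_embeddings")

def _grad_ckpt_context(model):
    # Hypothetical helper: a context manager that turns gradient checkpointing
    # on for the duration of the backward pass and restores it afterwards.
    @contextmanager
    def _ctx():
        model.gradient_checkpointing_enable()
        try:
            yield
        finally:
            model.gradient_checkpointing_disable()
    return _ctx()

def _is_param_grad_enabled(pname: str, model) -> bool:
    # Enable gradients only for embedding weights; per the docstring above, this
    # is enough to make all downstream activation gradients available.
    return "embed" in pname

# Assuming the class shown in the diff (the AutoQuantize searcher) is importable:
# AutoQuantizeSearcher.register_custom_support(
#     _is_supported, _grad_ckpt_context, _is_param_grad_enabled
# )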

@@ -793,10 +801,7 @@ def auto_quantize_score_estimate_forward(module, input, *args, **kwargs):
                 output_diff -= output
                 module.output_diff_dict[hparam][recipe] = output_diff.detach()

-        # Disable the configurable hparam so that they do not affect the any
-        # other hparam's score estimation
-        for hparam in module._hparams_for_scoring:
-            if hparam.is_configurable:
+                # Disable the configurable hparam now that we have computed the diff
                 hparam.active = QuantRecipe(quant_cfg=None)

         return output
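
The docstring added above claims that, for LLMs, enabling the weight gradient of just the embedding layer is enough for activation gradients to flow through every downstream layer. A small self-contained sanity check of that idea on a toy model (illustrative only; none of these names are ModelOpt API):

import torch
import torch.nn as nn

# Toy "LLM": an embedding followed by a small feed-forward stack.
model = nn.Sequential(
    nn.Embedding(100, 16),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
)

# Freeze everything except the embedding weight, mimicking what
# is_param_grad_enabled would select.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("0.")  # "0." is the embedding module

tokens = torch.randint(0, 100, (2, 8))
hidden = model[0](tokens)          # embedding output, requires_grad=True
out = hidden
for layer in list(model)[1:]:
    out = layer(out)

# Capture the gradient of an intermediate activation even though the
# linear layers' own weights have requires_grad=False.
grads = {}
hidden.register_hook(lambda g: grads.update(embedding_output=g))
out.sum().backward()
assert "embedding_output" in grads

Only the embedding weight is a trainable leaf here, yet the hook on the intermediate activation still receives a gradient; that is the property the custom support exploits to keep AutoQuantize's backward passes cheap.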

tests/unit/torch/quantization/plugins/test_huggingface.py

Lines changed: 0 additions & 2 deletions
@@ -160,8 +160,6 @@ def test_autoquantize_huggingface():
         verbose=True,
     )

-    print(search_history, model)
-

 @pytest.mark.parametrize(
     ("model_cls", "quant_config"),
