Commit 8c9cdd4

Add L2NormHook and use it in megatron.py (#599)
## What does this PR do?

- Add L2NormHook and use it in megatron.py.
- Using L2NormHook removes code duplication between _DynamicSelfAttention and _DynamicMLP.

This is the first step towards reusing the activation-scores logic across Minitron and Puzzle. Next steps:

- complete redesign of megatron.py
- move the other activation-hook logic to hooks.py
- then combine those hooks.py with the similar hooks.py functionality in Puzzle (modelopt/torch/_compress/activation_scoring/activation_hooks/hooks.py)

Questions:

- Why, in the code before and after this redesign, do we store temp variables in two ways (_register_temp_attribute and self.hook_handle)?

```
self._register_temp_attribute("_activation_hook", activation_hook)
# TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
```

---------

Signed-off-by: Daniel Korzekwa <[email protected]>
Signed-off-by: Daniel Korzekwa <[email protected]>
Co-authored-by: Keval Morabia <[email protected]>
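
As an illustration of the shared-hook pattern described above, here is a minimal standalone sketch: the `ToyL2NormHook` below is a simplified re-implementation of the hook's accumulation math (it omits the Megatron tensor-parallel gather), and plain `nn.Linear` modules stand in for `linear_fc2` / `linear_proj`, so everything outside the accumulation logic is assumed for the example.

```python
import torch
from torch import nn


class ToyL2NormHook:
    """Simplified stand-in for MegatronL2NormHook (no tensor-parallel gather)."""

    def __init__(self, max_size: int | None = None):
        self.max_size = max_size
        self._activations: torch.Tensor | None = None

    def __call__(self, module, args, output):
        x = args[0].detach().to(torch.float32)  # [seq_len, batch_size, hidden_size]
        if self.max_size is not None and x.shape[-1] != self.max_size:
            return  # skip non-max subnets, as in the real hook
        acts = x.abs().mean(dim=0).pow(2).sum(dim=0)  # mean over seq, square, sum over batch
        self._activations = acts if self._activations is None else self._activations + acts

    def accumulate(self) -> torch.Tensor:
        return self._activations.pow(0.5)  # sqrt of accumulated squared sums -> per-channel L2 norm


# The same hook class serves both module types, which is the duplication the PR removes.
fc2, proj = nn.Linear(16, 8), nn.Linear(16, 8)
fc2_hook, proj_hook = ToyL2NormHook(max_size=16), ToyL2NormHook(max_size=16)
fc2.register_forward_hook(fc2_hook)
proj.register_forward_hook(proj_hook)

for _ in range(4):  # a few "calibration" forwards
    x = torch.randn(32, 2, 16)  # [seq_len, batch_size, hidden_size]
    fc2(x)
    proj(x)

print(fc2_hook.accumulate().shape)  # torch.Size([16]) -- one importance score per channel
```

Because the hook is a callable object that owns its accumulated state, one class can be registered on both the MLP and the attention projection instead of each dynamic module carrying its own copy of the forward-hook method.
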
1 parent 194b532 commit 8c9cdd4

File tree

3 files changed, +184 -64 lines changed


modelopt/torch/nas/plugins/megatron.py

Lines changed: 32 additions & 64 deletions
@@ -55,7 +55,6 @@
 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
-from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
 from modelopt.torch.opt.searcher import ConstraintsDict
@@ -77,11 +76,12 @@
     ConstraintsRes,
 )
 from ..hparams.concat import build_concat_hp
-from ..modules import _DynamicLayerNorm
+from ..modules import DynamicModuleList, _DynamicLayerNorm
 from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices
 from ..registry import DMRegistry
 from ..search_space import SampleFunc
 from ..traced_hp import TracedHp
+from .megatron_hooks import MegatronL2NormHook
 
 SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}
 
@@ -265,39 +265,19 @@ def _setup(self):
         # can be discarded.
         # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator)
         # where we separate the importance estimation from the dynamic module.
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook)
+        max_ffn_size = int(self.get_hparam(self.hparam_name).max)  # type: ignore[arg-type]
+        activation_hook = MegatronL2NormHook(max_size=max_ffn_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
+        self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
         ffn_hidden_size.register_importance(self._estimate_importance)
 
-    def _linear_fc2_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-        if input.dim() == 2:
-            # For sparse experts, there is no batch dimension.
-            input = input[:, None, :]
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if input.shape[-1] != self.get_hparam(self.hparam_name).max:
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)  # [batch_size, ffn_hidden_size]
-        activations = activations.pow(2).sum(dim=0)  # [ffn_hidden_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the ffn_hidden_size."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        return self._activations.pow(0.5)
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
+        return self._activation_hook.accumulate()
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for shared expert."""
@@ -612,46 +592,26 @@ def _setup(self):
         )
 
         # register importance estimator for linear_qkv.output_size and linear_proj.input_size
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook)
+        num_heads_per_group_max = int(self.get_hparam("num_heads_per_group").max)  # type: ignore[arg-type]
+        num_query_groups_max = int(self.get_hparam("num_query_groups").max)  # type: ignore[arg-type]
+        max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
+        activation_hook = MegatronL2NormHook(max_size=max_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: confusion: why hook_handle is removed manually in export() and not using _register_temp_attribute?
+        self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
         # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
         # otherwise we would only have aggregated importance of heads per group.
         # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order
         num_heads_per_group.register_importance(self._estimate_all_head_importance)
         num_query_groups.register_importance(self._estimate_query_group_importance)
 
-    def _linear_proj_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, query_projection_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if (
-            input.shape[-1]
-            != self.get_hparam("num_heads_per_group").max
-            * self.get_hparam("num_query_groups").max
-            * self.config.kv_channels
-        ):
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)
-        activations = activations.pow(2).sum(dim=0)  # [query_projection_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_all_head_importance(self) -> TracedHp.Importance:
         """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
-        assert self._activations is not None, "No activations collected for importance estimation."
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
         # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         attn_head_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max
@@ -665,9 +625,11 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
 
     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
-        assert self._activations is not None, "No activations collected for importance estimation."
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
         # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         group_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max,
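
For reference, the per-head aggregation that these estimators apply to the hook's per-channel scores can be sketched standalone. The head/channel sizes below are made up, and the exact ord/dim arguments of the real vector_norm call are assumptions, since the diff view truncates that call.

```python
import torch

# Toy sizes for the sketch (not the real model config): 8 attention heads, 32 kv_channels.
num_heads, kv_channels = 8, 32
scores = torch.rand(num_heads * kv_channels)  # per-channel L2-norm scores from accumulate()

# Group channels by head and reduce each head's channels to one importance value.
attn_head_importance = torch.linalg.vector_norm(scores.view(num_heads, kv_channels), dim=1)
print(attn_head_importance.shape)  # torch.Size([8])
```
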
@@ -1594,8 +1556,11 @@ def get_activations_and_layer_scores(
         """Get the per-rank activations and layer scores from the module."""
         local_activations = {}
         for n, m in self.named_modules():
+            # TODO: Remove legacy _activations check once all modules use _activation_hook
             if hasattr(m, "_activations"):
                 local_activations[n] = m._activations
+            elif hasattr(m, "_activation_hook"):
+                local_activations[n] = m._activation_hook._activations
         activations_per_rank = dist.allgather(
             local_activations, group=get_pipeline_model_parallel_group()
         )
@@ -1624,8 +1589,11 @@ def set_activations_and_layer_scores(
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
         for n, m in self.named_modules():
+            # TODO: Remove legacy _activations check once all modules use _activation_hook
             if hasattr(m, "_activations"):
                 m._activations = activations_per_rank[rank][n]
+            elif hasattr(m, "_activation_hook"):
+                m._activation_hook._activations = activations_per_rank[rank][n]
 
 
 def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None:
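
Regarding the hook_handle TODOs above: a registered forward hook only stops firing once its handle is removed. Below is a hedged sketch of that register-in-_setup / remove-in-export lifecycle in plain PyTorch; HookedModule and its export() are illustrative stand-ins, not ModelOpt's actual export code.

```python
import torch
from torch import nn


class HookedModule(nn.Module):
    """Hypothetical stand-in illustrating the register-in-_setup / remove-in-export lifecycle."""

    def __init__(self):
        super().__init__()
        self.linear_fc2 = nn.Linear(8, 8)
        # Stand-in hook; the real code registers a MegatronL2NormHook instance here.
        self._activation_hook = lambda module, args, output: None
        self.hook_handle = self.linear_fc2.register_forward_hook(self._activation_hook)

    def export(self) -> None:
        # The handle must be removed explicitly for the hook to stop firing;
        # deleting the module attribute alone would leave the hook registered on linear_fc2.
        self.hook_handle.remove()


m = HookedModule()
m.linear_fc2(torch.randn(2, 8))  # hook fires here (and does nothing)
m.export()                       # hook is unregistered
```
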
modelopt/torch/nas/plugins/megatron_hooks.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Forward hooks for activation-based importance estimation (megatron NAS plugin)."""
+
+from abc import ABC, abstractmethod
+
+import torch
+from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region
+from torch import nn
+
+
+class ForwardHook(ABC):
+    """Base class for PyTorch forward hooks.
+
+    This follows the PyTorch forward hook API where the second
+    parameter is 'args' (a tuple of positional arguments passed to forward()).
+
+    Usage:
+        hook = MyHook()
+        module.register_forward_hook(hook)
+    """
+
+    @abstractmethod
+    def __call__(
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+    ) -> None:
+        """Forward hook that is called after the module's forward pass.
+
+        Args:
+            module: The module this hook is registered on
+            args: Tuple of positional arguments passed to module.forward()
+            output: The output from module.forward()
+
+        Returns:
+            None (does not modify the output)
+        """
+        ...
+
+
+class MegatronL2NormHook(ForwardHook):
+    """Hook for accumulating activation statistics for importance estimation.
+
+    Activations are computed as mean over seq_len and then squared and summed over batch_size.
+    In the accumulate() method we take the square root of the sum to get the L2 norm.
+
+    Args:
+        max_size: Optional maximum expected size to validate against (skips if mismatch).
+            Useful for skipping non-max subnets during profiling.
+    """
+
+    def __init__(self, max_size: int | None = None):
+        """Initialize the L2NormHook."""
+        self.max_size = max_size
+        self._activations: torch.Tensor | None = None
+
+    def __call__(
+        self, module: nn.Module, args: tuple[torch.Tensor, ...], output: torch.Tensor
+    ) -> None:
+        """Accumulate activation statistics from the forward pass."""
+        # Gather input [seq_len, batch_size, hidden_size] over all TP regions
+        # NOTE: This is not used at the moment since we restrict to TP=1
+        input_tensor = gather_from_tensor_model_parallel_region(args[0]).detach()
+
+        if input_tensor.dim() == 2:
+            # For sparse experts, there is no batch dimension.
+            input_tensor = input_tensor[:, None, :]
+
+        # Dont aggregate activations from non-max subnets (e.g. from profiling)
+        if self.max_size is not None and input_tensor.shape[-1] != self.max_size:
+            return
+
+        input_tensor = input_tensor.to(torch.float32)  # use full precision to avoid overflow
+        activations = input_tensor.abs().mean(dim=0)  # [batch_size, hidden_size]
+        activations = activations.pow(2).sum(dim=0)  # [hidden_size]
+
+        if self._activations is None:
+            self._activations = activations
+        else:
+            self._activations += activations
+
+    def accumulate(self) -> torch.Tensor:
+        """Return the accumulated L2 norm of activations.
+
+        Returns:
+            Tensor of accumulated scores, one per channel
+
+        Raises:
+            AssertionError: If no activations have been collected yet
+        """
+        assert self._activations is not None, "No activations collected for importance estimation."
+        # Convert squared sum to L2 norm
+        return self._activations.pow(0.5)
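
As a sanity check on the docstring above (tensor shapes made up for the example): accumulating the squared batch-sums per forward pass and taking the square root at the end is the same as computing a single per-channel L2 norm over all per-sample, seq-averaged activations.

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(32, 2, 16) for _ in range(3)]  # [seq_len, batch_size, hidden_size]

# What MegatronL2NormHook.__call__ accumulates across forward passes ...
acc = torch.zeros(16)
for x in batches:
    acc += x.abs().mean(dim=0).pow(2).sum(dim=0)
hook_style = acc.pow(0.5)  # ... and what accumulate() returns

# Direct per-channel L2 norm over all per-sample seq-mean activation vectors.
all_means = torch.cat([x.abs().mean(dim=0) for x in batches], dim=0)  # [total_batch, hidden]
direct = torch.linalg.vector_norm(all_means, dim=0)

assert torch.allclose(hook_style, direct, atol=1e-5)
```
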

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 48 additions & 0 deletions
@@ -87,10 +87,12 @@ def _get_model(initialize_megatron=True):
             normalization=normalization,
             num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
             num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage,
+            use_cpu_initialization=True,  # Ensure deterministic weight init across CUDA versions
         ).cuda()
         return model
 
     model = _get_model()
+
     sd = model.state_dict()
 
     def forward_loop(m):
@@ -134,6 +136,52 @@ def forward_loop(m):
     assert pruning_scores["layer_scores"]
     assert pruning_scores["activations_per_rank"]
 
+    # TODO: Simplify it: this unit test is too long,
+    # hard to read (the same set of assertions across different test cases with if-else).
+
+    assert len(pruning_scores["activations_per_rank"]) == 1
+    rank_0_activations = pruning_scores["activations_per_rank"][0]
+
+    # Test case 1: MHA - pruned ffn/4 (num_attention_heads=8, num_query_groups=8, ffn_div=4)
+    if pruned_ffn_div == 4:
+        # Layer scores
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.0868452191352844, abs=1e-3)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7638601660728455, abs=1e-3)
+
+        # Validate decoder.layers.0.mlp activations
+        mlp_0_acts = rank_0_activations["decoder.layers.0.mlp"]
+        assert mlp_0_acts.min().item() == pytest.approx(0.0015609927941114, abs=1e-3)
+        assert mlp_0_acts.max().item() == pytest.approx(0.3844809532165527, abs=1e-3)
+        assert mlp_0_acts.mean().item() == pytest.approx(0.0629318505525589, abs=1e-3)
+
+        # Validate decoder.layers.1.mlp activations
+        mlp_1_acts = rank_0_activations["decoder.layers.1.mlp"]
+        assert mlp_1_acts.min().item() == pytest.approx(0.0001484956446802, abs=1e-3)
+        assert mlp_1_acts.max().item() == pytest.approx(0.7835369110107422, abs=1e-3)
+        assert mlp_1_acts.mean().item() == pytest.approx(0.0926810950040817, abs=1e-3)
+
+    # Test case 2: GQA - pruned attention/2 (num_attention_heads=8, num_query_groups=4, attention_div=2)
+    elif pruned_num_attention_heads_div == 2 and pruned_ffn_div == 1:
+        # Layer scores
+        assert pruning_scores["layer_scores"][1] == pytest.approx(2.1415508985519409, abs=1e-3)
+        assert pruning_scores["layer_scores"][2] == pytest.approx(1.7198008894920349, abs=1e-3)
+
+        # Validate decoder.layers.0.self_attention activations
+        assert "decoder.layers.0.self_attention" in rank_0_activations
+        attn_0_acts = rank_0_activations["decoder.layers.0.self_attention"]
+        assert attn_0_acts.shape == torch.Size([256])
+        assert attn_0_acts.min().item() == pytest.approx(0.0409194342792034, abs=1e-3)
+        assert attn_0_acts.max().item() == pytest.approx(0.5261313319206238, abs=1e-3)
+        assert attn_0_acts.mean().item() == pytest.approx(0.1613342612981796, abs=1e-3)
+
+        # Validate decoder.layers.1.self_attention activations
+        assert "decoder.layers.1.self_attention" in rank_0_activations
+        attn_1_acts = rank_0_activations["decoder.layers.1.self_attention"]
+        assert attn_1_acts.shape == torch.Size([256])
+        assert attn_1_acts.min().item() == pytest.approx(0.1189328655600548, abs=1e-3)
+        assert attn_1_acts.max().item() == pytest.approx(1.3832759857177734, abs=1e-3)
+        assert attn_1_acts.mean().item() == pytest.approx(0.4782669544219971, abs=1e-3)
+
     # Assert weights are pruned correctly
     for layer in model.decoder.layers:
         assert layer.mlp.linear_fc1.weight.shape == (