 from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.transformer_layer import TransformerLayer
 
-from modelopt.torch.nas.modules import DynamicModuleList
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
 from modelopt.torch.opt.searcher import ConstraintsDict
@@ -77,11 +76,12 @@
     ConstraintsRes,
 )
 from ..hparams.concat import build_concat_hp
-from ..modules import _DynamicLayerNorm
+from ..modules import DynamicModuleList, _DynamicLayerNorm
 from ..modules.utils import get_sliced_tensor, get_sliced_tensor_by_slices
 from ..registry import DMRegistry
 from ..search_space import SampleFunc
 from ..traced_hp import TracedHp
+from .megatron_hooks import MegatronL2NormHook
 
 SUPPORTED_MODELS = {GPTModel: "megatron.core.models.gpt.GPTModel"}
 
@@ -265,39 +265,19 @@ def _setup(self):
         # can be discarded.
         # This limitation might be fixed in OMNIML-180 (Flexible Importance Estimator)
         # where we separate the importance estimation from the dynamic module.
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_fc2.register_forward_hook(self._linear_fc2_forward_hook)
+        max_ffn_size = int(self.get_hparam(self.hparam_name).max)  # type: ignore[arg-type]
+        activation_hook = MegatronL2NormHook(max_size=max_ffn_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: Clarify why hook_handle is removed manually in export() instead of via _register_temp_attribute.
+        self.hook_handle = self.linear_fc2.register_forward_hook(activation_hook)
         ffn_hidden_size.register_importance(self._estimate_importance)
 
-    def _linear_fc2_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, ffn_hidden_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-        if input.dim() == 2:
-            # For sparse experts, there is no batch dimension.
-            input = input[:, None, :]
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if input.shape[-1] != self.get_hparam(self.hparam_name).max:
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)  # [batch_size, ffn_hidden_size]
-        activations = activations.pow(2).sum(dim=0)  # [ffn_hidden_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_importance(self) -> TracedHp.Importance:
         """Return the activation magnitude-based importance of the ffn_hidden_size."""
-        assert self._activations is not None, "No activations collected for importance estimation."
-        # Convert squared sum to L2 norm
-        return self._activations.pow(0.5)
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
+        return self._activation_hook.accumulate()
 
     def set_hidden_size_hp(self, hidden_size: TracedHp) -> None:
         """Set hidden size for shared expert."""
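
The refactor delegates activation collection to `MegatronL2NormHook` from `.megatron_hooks`, which is not part of this diff. Below is a minimal sketch of what such a callable forward hook might look like, reconstructed from the removed `_linear_fc2_forward_hook` logic above; the real class may differ (e.g. it may still gather over tensor-parallel regions).

```python
# Hypothetical sketch of MegatronL2NormHook, based on the removed hook above.
# The TP gather (gather_from_tensor_model_parallel_region) is omitted since the
# surrounding code restricts to TP=1.
import torch


class MegatronL2NormHook:
    """Forward hook accumulating squared activation magnitudes for importance estimation."""

    def __init__(self, max_size: int):
        self._max_size = max_size  # hidden size of the max subnet
        self._activations: torch.Tensor | None = None

    def __call__(self, module, args, output) -> None:
        # args[0] is [seq_len, batch_size, hidden] (or [num_tokens, hidden] for sparse experts).
        inp = args[0].detach()
        if inp.dim() == 2:
            inp = inp[:, None, :]
        # Don't aggregate activations from non-max subnets (e.g. during profiling).
        if inp.shape[-1] != self._max_size:
            return
        inp = inp.to(torch.float32)  # full precision to avoid overflow
        scores = inp.abs().mean(dim=0).pow(2).sum(dim=0)  # [hidden]
        self._activations = scores if self._activations is None else self._activations + scores

    def accumulate(self) -> torch.Tensor:
        """Convert the accumulated squared sums into an L2 norm."""
        return self._activations.pow(0.5)
```
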
@@ -612,46 +592,26 @@ def _setup(self):
         )
 
         # register importance estimator for linear_qkv.output_size and linear_proj.input_size
-        self._register_temp_attribute("_activations", None)
-        self.hook_handle = self.linear_proj.register_forward_hook(self._linear_proj_forward_hook)
+        num_heads_per_group_max = int(self.get_hparam("num_heads_per_group").max)  # type: ignore[arg-type]
+        num_query_groups_max = int(self.get_hparam("num_query_groups").max)  # type: ignore[arg-type]
+        max_size = num_heads_per_group_max * num_query_groups_max * self.config.kv_channels
+        activation_hook = MegatronL2NormHook(max_size=max_size)
+        self._register_temp_attribute("_activation_hook", activation_hook)
+        # TODO: Clarify why hook_handle is removed manually in export() instead of via _register_temp_attribute.
+        self.hook_handle = self.linear_proj.register_forward_hook(activation_hook)
         # NOTE: num_heads_per_group's slice_order will be of length num_attention_heads to be able to sort heads,
         # otherwise we would only have aggregated importance of heads per group.
         # While enforcing order during `sort_parameters`, we dont check the shape of the slice_order
         num_heads_per_group.register_importance(self._estimate_all_head_importance)
         num_query_groups.register_importance(self._estimate_query_group_importance)
 
-    def _linear_proj_forward_hook(self, module, input, output):
-        """Hook to collect activations for importance estimation.
-
-        Activations are computed as mean over seq_len and then squared and summed over batch_size.
-        Later we take the square root of the sum to get the L2 norm.
-        """
-        # Gather input [seq_len, batch_size, query_projection_size] over all TP regions
-        # NOTE: This is not used at the moment since we restrict to TP=1
-        input = gather_from_tensor_model_parallel_region(input[0]).detach()
-
-        # Dont aggregate activations from non-max subnets (e.g. from profiling)
-        if (
-            input.shape[-1]
-            != self.get_hparam("num_heads_per_group").max
-            * self.get_hparam("num_query_groups").max
-            * self.config.kv_channels
-        ):
-            return
-
-        input = input.to(torch.float32)  # use full precision to avoid overflow
-        activations = input.abs().mean(dim=0)
-        activations = activations.pow(2).sum(dim=0)  # [query_projection_size]
-        if self._activations is None:
-            self._activations = activations
-        else:
-            self._activations += activations
-
     def _estimate_all_head_importance(self) -> TracedHp.Importance:
         """Return the importance for num_attention_heads (num_heads_per_group * num_query_groups)."""
-        assert self._activations is not None, "No activations collected for importance estimation."
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
         # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         attn_head_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max
@@ -665,9 +625,11 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
 
     def _estimate_query_group_importance(self) -> TracedHp.Importance:
         """Return the importance of the ``num_query_groups`` hparam."""
-        assert self._activations is not None, "No activations collected for importance estimation."
+        assert self._activation_hook._activations is not None, (
+            "No activations collected for importance estimation."
+        )
         # Convert squared sum to L2 norm
-        scores = self._activations.pow(0.5)
+        scores = self._activation_hook.accumulate()
         group_importance = torch.linalg.vector_norm(
             scores.view(
                 self.get_hparam("num_heads_per_group").max,
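
For intuition, the two estimators above take the per-channel L2 scores returned by `accumulate()` and reduce them with `torch.linalg.vector_norm` over the channel blocks belonging to each attention head or query group. The toy example below uses made-up sizes and an illustrative reshape; the actual `view` arguments in the truncated hunks above follow Megatron's QKV layout and may group the dimensions differently.

```python
# Toy illustration of reducing per-channel scores to per-head / per-group scores;
# sizes are invented and the reshape is illustrative, not Megatron's exact layout.
import torch

num_heads_per_group, num_query_groups, kv_channels = 4, 2, 8
scores = torch.rand(num_heads_per_group * num_query_groups * kv_channels)  # accumulate() output

# One L2 norm per attention head over its kv_channels block.
attn_head_importance = torch.linalg.vector_norm(
    scores.view(num_heads_per_group * num_query_groups, kv_channels), dim=-1
)
print(attn_head_importance.shape)  # torch.Size([8]) -> one score per head

# One L2 norm per query group over all of its heads' channels.
group_importance = torch.linalg.vector_norm(
    scores.view(num_query_groups, num_heads_per_group * kv_channels), dim=-1
)
print(group_importance.shape)  # torch.Size([2]) -> one score per query group
```
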
@@ -1594,8 +1556,11 @@ def get_activations_and_layer_scores(
         """Get the per-rank activations and layer scores from the module."""
         local_activations = {}
         for n, m in self.named_modules():
+            # TODO: Remove legacy _activations check once all modules use _activation_hook
             if hasattr(m, "_activations"):
                 local_activations[n] = m._activations
+            elif hasattr(m, "_activation_hook"):
+                local_activations[n] = m._activation_hook._activations
         activations_per_rank = dist.allgather(
             local_activations, group=get_pipeline_model_parallel_group()
         )
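
The same legacy/new branching is duplicated in the setter hunk below. A small hypothetical helper (not part of this diff) could centralize the lookup order introduced here:

```python
# Hypothetical helper; _get_module_activations does not exist in the diff and is
# shown only to illustrate the lookup order used by the getter above and setter below.
import torch
from torch import nn


def _get_module_activations(m: nn.Module) -> "torch.Tensor | None":
    if hasattr(m, "_activations"):  # legacy modules storing scores directly
        return m._activations
    if hasattr(m, "_activation_hook"):  # modules refactored to MegatronL2NormHook
        return m._activation_hook._activations
    return None
```
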
@@ -1624,8 +1589,11 @@ def set_activations_and_layer_scores(
         for layer in self.decoder.layers:
             layer._scores = layer_scores[layer.layer_number]
         for n, m in self.named_modules():
+            # TODO: Remove legacy _activations check once all modules use _activation_hook
             if hasattr(m, "_activations"):
                 m._activations = activations_per_rank[rank][n]
+            elif hasattr(m, "_activation_hook"):
+                m._activation_hook._activations = activations_per_rank[rank][n]
 
 
 def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[int]) -> None: