33Hacked together by / Copyright 2020 Ross Wightman
44"""
55from .model_ema import ModelEma
6-
6+ import torch
7+ import fnmatch
78
89def unwrap_model (model ):
910 if isinstance (model , ModelEma ):
@@ -14,3 +15,78 @@ def unwrap_model(model):
1415
1516def get_state_dict (model , unwrap_fn = unwrap_model ):
1617 return unwrap_fn (model ).state_dict ()
18+
19+
20+ def avg_sq_ch_mean (model , input , output ):
21+ "calculate average channel square mean of output activations"
22+ return torch .mean (output .mean (axis = [0 ,2 ,3 ])** 2 ).item ()
23+
24+
25+ def avg_ch_var (model , input , output ):
26+ "calculate average channel variance of output activations"
27+ return torch .mean (output .var (axis = [0 ,2 ,3 ])).item ()\
28+
29+
30+ def avg_ch_var_residual (model , input , output ):
31+ "calculate average channel variance of output activations"
32+ return torch .mean (output .var (axis = [0 ,2 ,3 ])).item ()
33+
34+
35+ class ActivationStatsHook :
36+ """Iterates through each of `model`'s modules and matches modules using unix pattern
37+ matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
38+ a match.
39+
40+ Arguments:
41+ model (nn.Module): model from which we will extract the activation stats
42+ hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
43+ matching with the name of model's modules.
44+ hook_fns (List[Callable]): List of hook functions to be registered at every
45+ module in `layer_names`.
46+
47+ Inspiration from https://docs.fast.ai/callback.hook.html.
48+
49+ Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
50+ on how to plot Signal Propogation Plots using `ActivationStatsHook`.
51+ """
52+
53+ def __init__ (self , model , hook_fn_locs , hook_fns ):
54+ self .model = model
55+ self .hook_fn_locs = hook_fn_locs
56+ self .hook_fns = hook_fns
57+ if len (hook_fn_locs ) != len (hook_fns ):
58+ raise ValueError ("Please provide `hook_fns` for each `hook_fn_locs`, \
59+ their lengths are different." )
60+ self .stats = dict ((hook_fn .__name__ , []) for hook_fn in hook_fns )
61+ for hook_fn_loc , hook_fn in zip (hook_fn_locs , hook_fns ):
62+ self .register_hook (hook_fn_loc , hook_fn )
63+
64+ def _create_hook (self , hook_fn ):
65+ def append_activation_stats (module , input , output ):
66+ out = hook_fn (module , input , output )
67+ self .stats [hook_fn .__name__ ].append (out )
68+ return append_activation_stats
69+
70+ def register_hook (self , hook_fn_loc , hook_fn ):
71+ for name , module in self .model .named_modules ():
72+ if not fnmatch .fnmatch (name , hook_fn_loc ):
73+ continue
74+ module .register_forward_hook (self ._create_hook (hook_fn ))
75+
76+
77+ def extract_spp_stats (model ,
78+ hook_fn_locs ,
79+ hook_fns ,
80+ input_shape = [8 , 3 , 224 , 224 ]):
81+ """Extract average square channel mean and variance of activations during
82+ forward pass to plot Signal Propogation Plots (SPP).
83+
84+ Paper: https://arxiv.org/abs/2101.08692
85+
86+ Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
87+ """
88+ x = torch .normal (0. , 1. , input_shape )
89+ hook = ActivationStatsHook (model , hook_fn_locs = hook_fn_locs , hook_fns = hook_fns )
90+ _ = model (x )
91+ return hook .stats
92+
0 commit comments