 from copy import copy
 from dataclasses import dataclass
 from numbers import Number
-from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 
 import numpy as np
 import torch
@@ -99,6 +99,18 @@ class Statistics:
     uncalled_mods: "Set[str]"
 
 
+def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
+    """
+    Like .named_modules(), but the results are slightly different for
+    some wrapped models.
+    """
+    seen = set()
+    for name, mod in _named_modules_with_dup(model):
+        if mod not in seen:
+            seen.add(mod)
+            yield name, mod
+
+
 def _get_scoped_trace_graph(
     module: nn.Module,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -139,7 +151,6 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
                 tracing_state.pop_scope()
             return outputs
 
-    seen = set()
     hook_handles: List[Any] = []
 
     def register_hooks(mod: nn.Module, name: str) -> None:
@@ -149,31 +160,21 @@ def register_hooks(mod: nn.Module, name: str) -> None:
         hook_handles.append(prehook)
         hook_handles.append(posthook)
 
-    # Torch script does not support parallel torch models, but we still
-    # want the scope names to be correct for the complete module.
+    # Unwrap DDP, but correct the scope names for the root module.
     if isinstance(
         module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
     ):
-
         # Since DataParallel just wraps the model, add an extra set of hooks
         # to the model it wraps to account for the wrapper. Then trace it.
         root_name = aliases[module]
        module = module.module
         register_hooks(module, root_name)
 
-    # We don't need the duplication here, but self._model.named_modules()
-    # gives slightly different results for some wrapped models.
-    for name, mod in _named_modules_with_dup(module):
-        if mod not in seen:
-            name = aliases[mod]
-            register_hooks(mod, name)
-            seen.add(mod)
+    for name, mod in _named_modules_without_dup(module):
+        name = aliases[mod]
+        register_hooks(mod, name)
 
-    if hasattr(torch.jit, "get_trace_graph"):
-        trace, _ = torch.jit.get_trace_graph(module, inputs)
-        graph = trace.graph()
-    else:
-        graph, _ = _get_trace_graph(module, inputs)
+    graph, _ = _get_trace_graph(module, inputs)
 
     for handle in hook_handles:
         handle.remove()
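
For readers skimming the diff, here is a minimal sketch (not part of the commit) of what the new helper does, assuming _named_modules_with_dup enumerates every (name, module) pair including repeated objects. The helper keeps only the first name seen per module object, which is the same filtering the removed inline `seen` set performed before registering hooks:

import torch.nn as nn

shared = nn.Linear(4, 4)
model = nn.Module()
model.a = shared
model.b = shared  # the same Linear registered under a second name

# _named_modules_with_dup(model) would yield ("", model), ("a", shared), ("b", shared);
# the dedup helper drops the second alias, so hooks get registered once per module.
assert [name for name, _ in _named_modules_without_dup(model)] == ["", "a"]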