@@ -111,6 +111,22 @@ def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
             yield name, mod


+def _maybe_flatten(object) -> List[torch.Tensor]:
+    # Tries its best to find all tensors within the object and put them
+    # into a flattened list. Custom structures cannot be recognized.
+    # TODO: improve coverage of other structures, e.g. by using __dict__
+    ret = []
+    if isinstance(object, torch.Tensor):
+        ret.append(object)
+    if isinstance(object, (list, tuple)):
+        for x in object:
+            ret.extend(_maybe_flatten(x))
+    if isinstance(object, dict):
+        for x in object.values():
+            ret.extend(_maybe_flatten(x))
+    return ret
+
+
 def _get_scoped_trace_graph(
     module: nn.Module,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
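
For intuition, this is what _maybe_flatten collects on a nested output; a minimal sketch, not part of the patch (the `out` structure below is made up):

    import torch

    out = {
        "logits": torch.rand(2, 10),
        "aux": [torch.rand(2, 4), (torch.rand(3),)],
        "meta": "ignored",  # non-tensor leaves are skipped
    }
    tensors = _maybe_flatten(out)
    assert len(tensors) == 3  # logits plus the two nested aux tensors
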
@@ -149,12 +165,14 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
             tracing_state = torch._C._get_tracing_state()
             if tracing_state:
                 tracing_state.pop_scope()
+            # Don't save all intermediate tensors on GPU. There could be a lot.
+            all_output_tensors.extend([x.cpu() for x in _maybe_flatten(outputs)])
             return outputs

+    all_output_tensors: List[torch.Tensor] = []
     hook_handles: List[Any] = []

-    def register_hooks(mod: nn.Module, name: str) -> None:
-        # pyre-fixme[29]: `Union[Tensor, nn.Module]` is not a function.
+    def register_hooks(mod, name: str) -> None:
         prehook = mod.register_forward_pre_hook(ScopePushHook(name))
         posthook = mod.register_forward_hook(ScopePopHook())
         hook_handles.append(prehook)
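
To isolate what the new hook body does, a minimal sketch assuming _maybe_flatten from the first hunk (the `capture` hook and the nn.Linear toy are illustrative, not part of the patch):

    import torch
    from torch import nn

    captured = []

    def capture(mod, inputs, outputs):
        # Same idea as the diff: flatten the outputs and move them to CPU
        # so intermediate activations don't accumulate on the GPU.
        captured.extend(t.cpu() for t in _maybe_flatten(outputs))

    layer = nn.Linear(4, 2)
    handle = layer.register_forward_hook(capture)
    layer(torch.rand(1, 4))
    handle.remove()
    assert len(captured) == 1  # the layer's output, now on CPU
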
@@ -174,6 +192,27 @@ def register_hooks(mod: nn.Module, name: str) -> None:
         name = aliases[mod]
         register_hooks(mod, name)

+    class WrapperModule(nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self._wrapped = module
+
+        def forward(self, *args):
+            # Some intermediate tensors may not be directly connected to the
+            # final model output, for example due to:
+            # * control flow not observed by tracing
+            # * tensor -> numpy/int conversion
+            # Operations that produce such tensors will get pruned by pytorch's
+            # DCE, but we want to include them in the graph.
+            # There is currently no way to disable DCE. So we capture all
+            # tensors we can and return them here, to reduce missing flops.
+            outputs = self._wrapped(*args)
+            return outputs, all_output_tensors
+
+    # Hooks were registered above, before wrapping, using the original scope
+    # names, so adding a wrapper here won't affect scopes.
+    module = WrapperModule(module)
+
     graph, _ = _get_trace_graph(module, inputs)

     for handle in hook_handles:
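
A sketch of the failure mode the wrapper works around (the Toy module below is hypothetical): when a tensor only leaves the graph through a Python scalar or numpy conversion, tracing's dead-code elimination drops the op that produced it, so its flops would go uncounted; returning the captured tensors alongside the real outputs keeps those ops alive in the trace.

    import torch
    from torch import nn

    class Toy(nn.Module):
        def forward(self, x):
            y = x.matmul(x.t())  # op we want counted
            _ = float(y.sum())   # consumed only as a Python scalar
            return x + 1         # y has no path to the traced output

    # Tracing Toy directly lets DCE prune the matmul; wrapping it so that
    # forward also returns the captured tensors (as above) preserves it.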