
Commit f38f5f0

Author: ppwwyyxx
Committed by: facebook-github-bot
flop count: capture output tensors from all layers so that unused layers are correctly counted.
Reviewed By: ericmintun
Differential Revision: D32242109
fbshipit-source-id: 362b68a2b7c50b1ec2efd0d415d7ec3e6b2ba1c8
Parent: c1ae9e7
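For context, here is a minimal sketch (mirroring the test added in this commit, not code from the commit itself) of the behavior being fixed: a submodule whose output never reaches the model's return value used to be pruned from the traced graph, so its flops were missed and it was reported as uncalled.

import torch
import torch.nn as nn
from fvcore.nn.flop_count import FlopCountAnalysis


class Net(nn.Module):  # hypothetical example model
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.fc1(x)  # output not connected to the returned tensor
        del y
        return self.fc2(x) + 2


analyzer = FlopCountAnalysis(model=Net(), inputs=(torch.randn(1, 10),))
_ = analyzer.total()
# With this change, fc1's output is captured by the scope hooks and returned by
# the wrapper module, so fc1 stays in the traced graph and is no longer reported here.
print(analyzer.uncalled_modules())  # expected: set()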

File tree: 2 files changed (+71, -5 lines)


fvcore/nn/jit_analysis.py

Lines changed: 41 additions & 2 deletions
@@ -111,6 +111,22 @@ def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
         yield name, mod
 
 
+def _maybe_flatten(object) -> List[torch.Tensor]:
+    # Try its best to find all tensors within the object and put them
+    # into a flattened list. Custom structures cannot be recognized.
+    # TODO: improve coverage of other structures, e.g. by using __dict__
+    ret = []
+    if isinstance(object, torch.Tensor):
+        ret.append(object)
+    if isinstance(object, (list, tuple)):
+        for x in object:
+            ret.extend(_maybe_flatten(x))
+    if isinstance(object, dict):
+        for x in object.values():
+            ret.extend(_maybe_flatten(x))
+    return ret
+
+
 def _get_scoped_trace_graph(
     module: nn.Module,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -149,12 +165,14 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
             tracing_state = torch._C._get_tracing_state()
             if tracing_state:
                 tracing_state.pop_scope()
+            # Don't save all intermediate tensors on GPU. There could be a lot.
+            all_output_tensors.extend([x.cpu() for x in _maybe_flatten(outputs)])
             return outputs
 
+    all_output_tensors: List[torch.Tensor] = []
     hook_handles: List[Any] = []
 
-    def register_hooks(mod: nn.Module, name: str) -> None:
-        # pyre-fixme[29]: `Union[Tensor, nn.Module]` is not a function.
+    def register_hooks(mod, name: str) -> None:
         prehook = mod.register_forward_pre_hook(ScopePushHook(name))
         posthook = mod.register_forward_hook(ScopePopHook())
         hook_handles.append(prehook)
@@ -174,6 +192,27 @@ def register_hooks(mod: nn.Module, name: str) -> None:
         name = aliases[mod]
         register_hooks(mod, name)
 
+    class WrapperModule(nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self._wrapped = module
+
+        def forward(self, *args):
+            # Some intermediate tensors may not be directly connected to the final model
+            # output, for example due to:
+            # * control flow not observed by tracing
+            # * tensor -> numpy/int conversion
+            # Operations that produce such tensors will get pruned by pytorch's DCE,
+            # but we want to include them in the graph.
+            # There is currently no way to disable DCE. So we capture all tensors we can
+            # and return them here, to reduce missing flops.
+            outputs = self._wrapped(*args)
+            return outputs, all_output_tensors
+
+    # Hooks are registered before wrapping with their original scope names, so
+    # adding a wrapper here won't affect scopes.
+    module = WrapperModule(module)
+
     graph, _ = _get_trace_graph(module, inputs)
 
     for handle in hook_handles:

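To illustrate the helper introduced above, here is a quick sketch of what _maybe_flatten returns for a nested output structure. It is a private helper of fvcore.nn.jit_analysis; importing it directly is done here only for illustration.

import torch
from fvcore.nn.jit_analysis import _maybe_flatten  # private helper, imported only to illustrate

outputs = {
    "logits": torch.randn(2, 10),
    "aux": [torch.randn(2), (torch.randn(3), "not a tensor")],
    "meta": 42,  # non-tensor leaves are ignored
}
tensors = _maybe_flatten(outputs)
print(len(tensors))  # expected: 3 -- the tensors found inside the dict, list and tuple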
tests/test_jit_model_analysis.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
import unittest
77
import warnings
88
from collections import Counter
9-
from typing import Any, Dict, List
9+
from typing import Any, Dict, List, Union
1010

1111
import torch
1212
import torch.nn as nn
1313
from fvcore.nn.flop_count import FlopCountAnalysis
1414
from fvcore.nn.jit_analysis import JitModelAnalysis
1515
from fvcore.nn.jit_handles import Handle, addmm_flop_jit, conv_flop_jit, linear_flop_jit
16+
from torch.nn import functional as F
1617

1718

1819
class NestedNetInnerModule(nn.Module):
@@ -283,20 +284,28 @@ class TraceWarningNet(nn.Module):
     will be skipped and raise a warning.
     """
 
+    class IntLinear(nn.Linear):
+        """
+        A linear that outputs int, therefore cannot be traced.
+        """
+
+        def forward(self, x) -> Union[float, int]:
+            return F.linear(x, self.weight, self.bias).item()
+
     def __init__(self) -> None:
         super().__init__()
         self.input_size = (10,)
         fc1_in, fc1_out = 10, 1
         fc2_in, fc2_out = 10, 10
 
-        self.fc1 = nn.Linear(in_features=fc1_in, out_features=fc1_out)
+        self.fc1 = TraceWarningNet.IntLinear(in_features=fc1_in, out_features=fc1_out)
         self.fc2 = nn.Linear(in_features=fc2_in, out_features=fc2_out)
 
         self.fc1_flops = fc1_in * fc1_out  # type: int
         self.fc2_flops = fc2_in * fc2_out  # type: int
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        y = self.fc1(x).item()
+        y = self.fc1(x)
         warnings.warn("Dummy RuntimeWarning.", RuntimeWarning)
         if y < 0.0:
             x = self.fc2(x)
@@ -806,6 +815,24 @@ def test_disable_warnings(self) -> None:
         self.assertTrue(any(uncalled_msg in s for s in cm.output))
         self.assertTrue(any(uncalled_modules in s for s in cm.output))
 
+    def test_capture_intermediate_outputs(self) -> None:
+        class TestCaptureNet(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.fc1 = nn.Linear(10, 1)
+                self.fc2 = nn.Linear(10, 10)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                y = self.fc1(x)
+                del y  # unused by output
+                return self.fc2(x) + 2
+
+        model = TestCaptureNet()
+        inputs = (torch.randn((1, 10)),)
+        analyzer = FlopCountAnalysis(model=model, inputs=inputs)
+        _ = analyzer.total()
+        self.assertEqual(analyzer.uncalled_modules(), set())
+
     def test_skip_uncalled_containers_warnings(self) -> None:
         # uncalled containers should not warn

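As a standalone sketch (not part of the commit) of the mechanism WrapperModule relies on: returning an otherwise-unused tensor from forward keeps its producing op in the traced graph, whereas PyTorch's dead-code elimination prunes it otherwise. The module names and the loose op-kind match below are illustrative assumptions.

import torch
import torch.nn as nn


class DropUnused(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(10, 1)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        y = self.fc1(x)  # output never reaches the return value -> pruned by DCE
        return self.fc2(x)


class KeepAll(DropUnused):
    def forward(self, x):
        y = self.fc1(x)
        return self.fc2(x), y  # also returning y keeps fc1's op alive in the trace


def num_linear_nodes(traced) -> int:
    # Count nodes that look like a linear layer; the exact op name varies by
    # PyTorch version (aten::linear vs. aten::addmm), hence the loose match.
    return sum(
        1 for n in traced.graph.nodes() if "linear" in n.kind() or "addmm" in n.kind()
    )


x = torch.randn(1, 10)
print(num_linear_nodes(torch.jit.trace(DropUnused(), (x,))))  # expected: 1 (fc1 pruned)
print(num_linear_nodes(torch.jit.trace(KeepAll(), (x,))))     # expected: 2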