
Commit 9a2ba6a

ppwwyyxx authored and facebook-github-bot committed

flop counter: clean-up of tracing logic

Summary:

1. Remove the check for _get_trace_graph: the old torch.jit.get_trace_graph API has been gone from PyTorch for long enough.
2. Add a test case covering DataParallel.
3. Move _named_modules_without_dup into a separate function so the tracing function reads more simply.

Reviewed By: ericmintun

Differential Revision: D32227000

fbshipit-source-id: fada03028b48db33ef2e2058853c6dda39750251
1 parent a914090 commit 9a2ba6a
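For context, a minimal sketch of the API this commit touches, mirroring the wrapper-plus-linear setup of the new test below; the shapes and printed values are illustrative, not part of the commit:

```python
import torch
import torch.nn as nn
from fvcore.nn import FlopCountAnalysis

# DataParallel merely wraps the real model; the tracing logic cleaned up
# in this commit must unwrap it while keeping scope names rooted at the
# wrapper.
model = nn.DataParallel(nn.Linear(10, 10))
flops = FlopCountAnalysis(model, (torch.randn(10, 10),))
print(flops.total())      # 1000 (fvcore counts one flop per multiply-accumulate)
print(flops.by_module())  # per-scope totals rely on the scoped trace below
```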

File tree

2 files changed: +26 −17 lines


fvcore/nn/jit_analysis.py

Lines changed: 18 additions & 17 deletions
@@ -8,7 +8,7 @@
 from copy import copy
 from dataclasses import dataclass
 from numbers import Number
-from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypeVar, Union
 
 import numpy as np
 import torch
@@ -99,6 +99,18 @@ class Statistics:
     uncalled_mods: "Set[str]"
 
 
+def _named_modules_without_dup(model: nn.Module) -> Iterator[Tuple[str, nn.Module]]:
+    """
+    Like .named_modules(), but the results are slightly different for
+    some wrapped models.
+    """
+    seen = set()
+    for name, mod in _named_modules_with_dup(model):
+        if mod not in seen:
+            seen.add(mod)
+            yield name, mod
+
+
 def _get_scoped_trace_graph(
     module: nn.Module,
     inputs: Union[Tensor, Tuple[Tensor, ...]],
@@ -139,7 +151,6 @@ def __call__(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
             tracing_state.pop_scope()
             return outputs
 
-    seen = set()
     hook_handles: List[Any] = []
 
     def register_hooks(mod: nn.Module, name: str) -> None:
@@ -149,31 +160,21 @@ def register_hooks(mod: nn.Module, name: str) -> None:
         hook_handles.append(prehook)
         hook_handles.append(posthook)
 
-    # Torch script does not support parallel torch models, but we still
-    # want the scope names to be correct for the complete module.
+    # Unwrap DDP, but correct the scope names for the root module.
     if isinstance(
         module, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
     ):
-
         # Since DataParallel just wraps the model, add an extra set of hooks
         # to the model it wraps to account for the wrapper. Then trace it.
         root_name = aliases[module]
         module = module.module
         register_hooks(module, root_name)
 
-    # We don't need the duplication here, but self._model.named_modules()
-    # gives slightly different results for some wrapped models.
-    for name, mod in _named_modules_with_dup(module):
-        if mod not in seen:
-            name = aliases[mod]
-            register_hooks(mod, name)
-            seen.add(mod)
+    for name, mod in _named_modules_without_dup(module):
+        name = aliases[mod]
+        register_hooks(mod, name)
 
-    if hasattr(torch.jit, "get_trace_graph"):
-        trace, _ = torch.jit.get_trace_graph(module, inputs)
-        graph = trace.graph()
-    else:
-        graph, _ = _get_trace_graph(module, inputs)
+    graph, _ = _get_trace_graph(module, inputs)
 
     for handle in hook_handles:
         handle.remove()
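To illustrate the helper factored out above: a standalone sketch of first-seen deduplication, where named_modules(remove_duplicate=False) stands in for fvcore's _named_modules_with_dup and the Wrapper model is made up for the example:

```python
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        shared = nn.Linear(4, 4)
        self.a = shared
        self.b = shared  # same module object, reachable under a second name


m = Wrapper()

# Every path to every module, duplicates included (a stand-in for
# _named_modules_with_dup).
with_dup = list(m.named_modules(remove_duplicate=False))
print([name for name, _ in with_dup])  # ['', 'a', 'b']

# First-seen dedup, as in _named_modules_without_dup: each module object
# keeps the first name under which it was reached.
seen = set()
without_dup = []
for name, mod in with_dup:
    if mod not in seen:
        seen.add(mod)
        without_dup.append((name, mod))
print([name for name, _ in without_dup])  # ['', 'a']
```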

tests/test_jit_model_analysis.py

Lines changed: 8 additions & 0 deletions
@@ -566,6 +566,14 @@ def test_data_parallel(self) -> None:
         # Test no uncalled modules
         self.assertEqual(analyzer.uncalled_modules(), set())
 
+    def test_data_parallel_root_scope(self) -> None:
+        # A test case discussed in D32227000
+        model = nn.DataParallel(nn.Linear(10, 10))
+        for mode in ["caller", "owner"]:
+            flop = FlopCountAnalysis(model, (torch.randn(10, 10),))
+            flop.ancestor_mode(mode)
+            self.assertEqual(flop.total(), 1000)
+
     def test_unsupported_ops(self) -> None:
         """
         Tests per-module recording of unsupported operations.
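A note on the constant asserted in the new test: fvcore counts one flop per multiply-accumulate, so nn.Linear(10, 10) applied to a (10, 10) input performs batch × in_features × out_features MACs:

```python
# The value the test expects from flop.total().
batch, in_features, out_features = 10, 10, 10
print(batch * in_features * out_features)  # 1000
```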
