Skip to content

Commit c1ae9e7

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
flop count: do not ignore scripted function
Reviewed By: bxiong1202 Differential Revision: D32233496 fbshipit-source-id: d7cdbb2e1107bd5a12a952ea44f2eb56355de2df
1 parent 9a2ba6a commit c1ae9e7

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

fvcore/nn/jit_analysis.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -575,14 +575,8 @@ def _analyze(self) -> "Statistics":
575575
ancestors = self._get_all_ancestors(scope_names[-1])
576576
all_seen.update(ancestors)
577577
if kind not in self._op_handles:
578-
# Ignore all prim:: operators. However, prim::PythonOp can be
579-
# a user-implemented `torch.autograd.Function` so we shouldn't
580-
# ignore it.
581-
if kind in self._ignored_ops or (
582-
kind.startswith("prim::") and not kind.startswith("prim::PythonOp")
583-
):
578+
if self._should_ignore_node(node):
584579
continue
585-
586580
for name in ancestors:
587581
unsupported_ops[name][kind] += 1
588582
else:
@@ -640,3 +634,16 @@ def _has_forward(self, mod_name: str) -> bool:
640634
if module_type.forward is mod.forward: # pyre-ignore[16]
641635
return False
642636
return True
637+
638+
def _should_ignore_node(self, node) -> bool:
639+
kind = node.kind()
640+
if kind in self._ignored_ops:
641+
return True
642+
# Ignore all prim:: operators, with two exceptions:
643+
# * prim::PythonOp can be a user-implemented `torch.autograd.Function`
644+
# * prim::CallFunction an be a call to scripted module/function.
645+
if kind.startswith("prim::PythonOp") or kind.startswith("prim::CallFunction"):
646+
return False
647+
if kind.startswith("prim::"):
648+
return True
649+
return False

tests/test_flop_count.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,21 @@ def forward(self, x):
829829
)
830830
self.assertEqual(flop.total(), 42)
831831

832+
def test_scripted_function(self):
833+
# Scripted function is not yet supported. It should produce a warning
834+
835+
def func(x):
836+
return x @ x
837+
838+
class Mod(nn.Module):
839+
def forward(self, x):
840+
f = torch.jit.script(func)
841+
return f(x * x)
842+
843+
flop = FlopCountAnalysis(Mod(), (torch.rand(5, 5),))
844+
_ = flop.total()
845+
self.assertIn("prim::CallFunction", flop.unsupported_ops())
846+
832847

833848
class TestFlopCountHandles(unittest.TestCase):
834849
def _count_function(self, func, inputs, name) -> Tuple[Any, Any]:

0 commit comments

Comments
 (0)