Skip to content

Commit 4cd6e96

Browse files
malfet and pytorchmergebot
authored and committed
[MPSInductor] Fix nested loop var elimination (pytorch#156566)
As reduction results must be kept around. Add regression test that is specific for this issue. Fixes pytorch#156426. Pull Request resolved: pytorch#156566. Approved by: https://github.com/dcci
1 parent d55dc00 commit 4cd6e96

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

test/inductor/test_torchinductor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5530,6 +5530,13 @@ def fn(x):
55305530
(torch.randn([1, 2, 4, 8]),),
55315531
)
55325532

5533+
def test_var_mean_div_by(self):
5534+
def fn(x):
5535+
var, mean = torch.var_mean(x, dim=2, keepdim=True)
5536+
return x / var, var, mean
5537+
5538+
self.common(fn, (torch.rand([1, 17, 2048]),))
5539+
55335540
def test_var_correction(self):
55345541
def fn(x):
55355542
dim = -1

torch/_inductor/codegen/mps.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,16 @@ def codegen_body(self) -> None:
754754
self.body.splice(self.compute)
755755
self.body.writeline("}" * len(self.multistage_reduction_entry))
756756
# Invalidate variables instantiated inside loop
757-
self.cse.invalidate(OrderedSet(self.cse.reduction_cache.values()))
757+
# But results of reduction alive. Reduction cache values can be
758+
# either CSEVariable or tuple of CSEVariables, in which case all
759+
# variables in the tuple must be preserved
760+
self.cse.invalidate(
761+
OrderedSet(
762+
v
763+
for item in self.cse.reduction_cache.values()
764+
for v in (item if isinstance(item, tuple) else (item,))
765+
)
766+
)
758767
# And loop codegen
759768
while self.multistage_reduction_entry:
760769
self.multistage_reduction_entry.pop().cache_clear()

0 commit comments

Comments
 (0)