Skip to content

Commit ddcf053

Browse files
Luca CitiricardoV94
authored andcommitted
Added log1mexp(log(x)) -> log1p(-x) and its test
Also implemented tests as suggested by ricardoV94
1 parent b9ea6df commit ddcf053

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def local_log_sqrt(fgraph, node):
576576

577577

578578
@register_specialize
579-
@node_rewriter([exp, expm1, softplus])
579+
@node_rewriter([exp, expm1, log1pexp, log1mexp])
580580
def local_exp_log_nan_switch(fgraph, node):
581581
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
582582
x = node.inputs[0]
@@ -629,13 +629,20 @@ def local_exp_log_nan_switch(fgraph, node):
629629
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
630630
return [new_out]
631631

632-
# Case for softplus(log(x)) -> log1p(x)
632+
# Case for log1pexp(log(x)) -> log1p(x) (log1pexp aka softplus)
633633
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus):
634634
x = x.owner.inputs[0]
635635
old_out = node.outputs[0]
636636
new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype))
637637
return [new_out]
638638

639+
# Case for log1mexp(log(x)) -> log1p(-x)
640+
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Log1mexp):
641+
x = x.owner.inputs[0]
642+
old_out = node.outputs[0]
643+
new_out = switch(ge(x, 0), log1p(-x), np.asarray(np.nan, old_out.dtype))
644+
return [new_out]
645+
639646

640647
@register_canonicalize
641648
@register_specialize

tests/tensor/rewriting/test_math.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
log,
7070
log1mexp,
7171
log1p,
72+
log1pexp,
7273
lt,
7374
maximum,
7475
minimum,
@@ -1968,27 +1969,53 @@ def test_exp_softplus(self, exp_op):
19681969
decimal=6,
19691970
)
19701971

1971-
def test_softplus_log(self):
1972-
# softplus(log(x)) -> log1p(x)
1972+
def test_log1pexp_log(self):
1973+
# log1pexp(log(x)) -> log1p(x)
19731974
data_valid = np.random.random((4, 3)).astype("float32") * 2
19741975
data_valid[0, 0] = 0 # edge case
19751976
data_invalid = data_valid - 2
19761977

19771978
x = fmatrix()
1978-
f = function([x], softplus(log(x)), mode=self.mode)
1979-
graph = f.maker.fgraph.toposort()
1980-
ops_graph = [
1981-
node
1982-
for node in graph
1983-
if isinstance(node.op, Elemwise)
1984-
and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus)
1985-
]
1986-
assert len(ops_graph) == 0
1979+
f = function([x], log1pexp(log(x)), mode=self.mode.excluding("inplace"))
1980+
assert equal_computations(
1981+
f.maker.fgraph.outputs,
1982+
[
1983+
pt.switch(
1984+
x >= np.array([[0]], dtype=np.int8),
1985+
pt.log1p(x),
1986+
np.array([[np.nan]], dtype=np.float32),
1987+
)
1988+
],
1989+
)
19871990

19881991
expected = np.log1p(data_valid)
19891992
np.testing.assert_almost_equal(f(data_valid), expected)
19901993
assert np.all(np.isnan(f(data_invalid)))
19911994

1995+
def test_log1mexp_log(self):
1996+
# log1mexp(log(x)) -> log1p(-x)
1997+
data_valid = np.random.random((4, 3)).astype("float32")
1998+
data_valid[0, 0] = 0 # edge case
1999+
data_valid[0, 1] = 1 # another edge case
2000+
data_invalid = np.concatenate([data_valid + 1.1, data_valid - 1.1])
2001+
2002+
x = fmatrix()
2003+
f = function([x], log1mexp(log(x)), mode=self.mode.excluding("inplace"))
2004+
assert equal_computations(
2005+
f.maker.fgraph.outputs,
2006+
[
2007+
pt.switch(
2008+
x >= np.array([[0]], dtype=np.int8),
2009+
pt.log1p(-x),
2010+
np.array([[np.nan]], dtype=np.float32),
2011+
)
2012+
],
2013+
)
2014+
2015+
expected = np.log1p(-data_valid)
2016+
np.testing.assert_almost_equal(f(data_valid), expected)
2017+
assert np.all(np.isnan(f(data_invalid)))
2018+
19922019
@pytest.mark.parametrize(
19932020
["nested_expression", "expected_switches"],
19942021
[

0 commit comments

Comments
 (0)