Skip to content

Commit b9ea6df

Browse files
Luca CitiricardoV94
authored andcommitted
Add rewrite for softplus(log(x)) -> log1p(x)
1 parent 75b8ee9 commit b9ea6df

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def local_log_sqrt(fgraph, node):
576576

577577

578578
@register_specialize
579-
@node_rewriter([exp, expm1])
579+
@node_rewriter([exp, expm1, softplus])
580580
def local_exp_log_nan_switch(fgraph, node):
581581
# Rewrites of the kind exp(log...(x)) that require a `nan` switch
582582
x = node.inputs[0]
@@ -629,6 +629,13 @@ def local_exp_log_nan_switch(fgraph, node):
629629
new_out = switch(le(x, 0), neg(exp(x)), np.asarray(np.nan, old_out.dtype))
630630
return [new_out]
631631

632+
# Case for softplus(log(x)) -> log1p(x)
633+
if isinstance(prev_op, ps.Log) and isinstance(node_op, ps_math.Softplus):
634+
x = x.owner.inputs[0]
635+
old_out = node.outputs[0]
636+
new_out = switch(ge(x, 0), log1p(x), np.asarray(np.nan, old_out.dtype))
637+
return [new_out]
638+
632639

633640
@register_canonicalize
634641
@register_specialize

tests/tensor/rewriting/test_math.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1968,6 +1968,27 @@ def test_exp_softplus(self, exp_op):
19681968
decimal=6,
19691969
)
19701970

1971+
def test_softplus_log(self):
1972+
# softplus(log(x)) -> log1p(x)
1973+
data_valid = np.random.random((4, 3)).astype("float32") * 2
1974+
data_valid[0, 0] = 0 # edge case
1975+
data_invalid = data_valid - 2
1976+
1977+
x = fmatrix()
1978+
f = function([x], softplus(log(x)), mode=self.mode)
1979+
graph = f.maker.fgraph.toposort()
1980+
ops_graph = [
1981+
node
1982+
for node in graph
1983+
if isinstance(node.op, Elemwise)
1984+
and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus)
1985+
]
1986+
assert len(ops_graph) == 0
1987+
1988+
expected = np.log1p(data_valid)
1989+
np.testing.assert_almost_equal(f(data_valid), expected)
1990+
assert np.all(np.isnan(f(data_invalid)))
1991+
19711992
@pytest.mark.parametrize(
19721993
["nested_expression", "expected_switches"],
19731994
[

0 commit comments

Comments
 (0)