|
34 | 34 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
35 | 35 | # SOFTWARE. |
36 | 36 |
|
37 | | -from typing import Dict, Optional, Tuple |
| 37 | +from typing import Dict, Optional, Sequence, Tuple |
38 | 38 |
|
39 | 39 | import pytensor.tensor as pt |
40 | 40 |
|
|
43 | 43 | from pytensor.graph.features import Feature |
44 | 44 | from pytensor.graph.fg import FunctionGraph |
45 | 45 | from pytensor.graph.rewriting.basic import GraphRewriter, node_rewriter |
46 | | -from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB |
| 46 | +from pytensor.graph.rewriting.db import ( |
| 47 | + EquilibriumDB, |
| 48 | + LocalGroupDB, |
| 49 | + RewriteDatabaseQuery, |
| 50 | + SequenceDB, |
| 51 | + TopoDB, |
| 52 | +) |
47 | 53 | from pytensor.tensor.elemwise import DimShuffle, Elemwise |
48 | 54 | from pytensor.tensor.extra_ops import BroadcastTo |
49 | 55 | from pytensor.tensor.random.rewriting import local_subtensor_rv_lift |
50 | | -from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless |
| 56 | +from pytensor.tensor.rewriting.basic import register_canonicalize |
51 | 57 | from pytensor.tensor.rewriting.shape import ShapeFeature |
52 | 58 | from pytensor.tensor.subtensor import ( |
53 | 59 | AdvancedIncSubtensor, |
@@ -191,9 +197,8 @@ def local_lift_DiracDelta(fgraph, node): |
191 | 197 | return new_node.outputs |
192 | 198 |
|
193 | 199 |
|
194 | | -@register_useless |
195 | | -@node_rewriter((DiracDelta,)) |
196 | | -def local_remove_DiracDelta(fgraph, node): |
| 200 | +@node_rewriter([DiracDelta]) |
| 201 | +def remove_DiracDelta(fgraph, node): |
197 | 202 | r"""Remove `DiracDelta`\s.""" |
198 | 203 | dd_val = node.inputs[0] |
199 | 204 | return [dd_val] |
@@ -270,6 +275,17 @@ def incsubtensor_rv_replace(fgraph, node): |
270 | 275 |
|
271 | 276 | logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic") |
272 | 277 |
|
| 278 | +# Rewrites that remove IR Ops |
| 279 | +cleanup_ir_rewrites_db = LocalGroupDB() |
| 280 | +cleanup_ir_rewrites_db.name = "cleanup_ir_rewrites_db" |
| 281 | +logprob_rewrites_db.register( |
| 282 | + "cleanup_ir_rewrites", |
| 283 | + TopoDB(cleanup_ir_rewrites_db, order="out_to_in", ignore_newtrees=True, failure_callback=None), |
| 284 | + "cleanup", |
| 285 | +) |
| 286 | + |
| 287 | +cleanup_ir_rewrites_db.register("remove_DiracDelta", remove_DiracDelta, "cleanup") |
| 288 | + |
273 | 289 |
|
274 | 290 | def construct_ir_fgraph( |
275 | 291 | rv_values: Dict[Variable, Variable], |
@@ -351,3 +367,9 @@ def construct_ir_fgraph( |
351 | 367 | fgraph.replace_all(new_to_old) |
352 | 368 |
|
353 | 369 | return fgraph, rv_values, memo |
| 370 | + |
| 371 | + |
| 372 | +def cleanup_ir(vars: Sequence[Variable]) -> None: |
| 373 | + fgraph = FunctionGraph(outputs=vars, clone=False) |
| 374 | + ir_rewriter = logprob_rewrites_db.query(RewriteDatabaseQuery(include=["cleanup"])) |
| 375 | + ir_rewriter.rewrite(fgraph) |
0 commit comments