We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
input_output_aliases
pallas_call
1 parent 07d01b4 commit f187196Copy full SHA for f187196
jax/_src/pallas/pallas_call.py
@@ -1556,7 +1556,7 @@ def pallas_call(
1556
etc.
1557
input_output_aliases: a dictionary mapping the index of some inputs to
1558
the index of the output that aliases them. These indices are in the
1559
- flattened inputs and outputs.
+ flattened inputs and outputs (ignoring None values).
1560
debug: if True, Pallas prints various intermediate forms of the kernel
1561
as it is being processed.
1562
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
0 commit comments