Skip to content

Commit f187196

Browse files
Clarify input_output_aliases indexing in pallas_call.
PiperOrigin-RevId: 800063159
1 parent 07d01b4 commit f187196

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/pallas/pallas_call.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1556,7 +1556,7 @@ def pallas_call(
15561556
etc.
15571557
input_output_aliases: a dictionary mapping the index of some inputs to
15581558
the index of the output that aliases them. These indices are in the
1559-
flattened inputs and outputs.
1559+
flattened inputs and outputs (ignoring None values).
15601560
debug: if True, Pallas prints various intermediate forms of the kernel
15611561
as it is being processed.
15621562
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the

0 commit comments

Comments
 (0)