From a8e73333b8dddb6bd248765e3f7c08f72bb8fcfb Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 2 Sep 2025 13:59:54 -0700 Subject: [PATCH] Clarify `input_output_aliases` indexing in `pallas_call`. PiperOrigin-RevId: 802281892 --- jax/_src/pallas/pallas_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 7922a9a75862..6afe530724bb 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1565,7 +1565,7 @@ def pallas_call( etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the - flattened inputs and outputs. + flattened inputs and outputs (ignoring None values). debug: if True, Pallas prints various intermediate forms of the kernel as it is being processed. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the