From 8e1c0085ad36f1fc1eda9b102b5dd6430a154ff0 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 5 May 2025 23:03:44 +0000 Subject: [PATCH 001/103] _valu --- keras/src/backend/jax/core.py | 82 ++++++++++++++++++++++++++++++++++ keras/src/backend/jax/layer.py | 15 ++++++- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 3e8657e4caa0..7166b732d0f8 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -3,6 +3,7 @@ import jax.numpy as jnp import ml_dtypes import numpy as np +from flax import nnx from keras.src import tree from keras.src.backend.common import KerasVariable @@ -19,7 +20,88 @@ IS_THREAD_SAFE = True +def in_stateless_scope(): + return global_state.get_global_attribute("stateless_scope") is not None + + +def get_stateless_scope(): + return global_state.get_global_attribute("stateless_scope") + + +def shape_equal(a_shape, b_shape): + """Return whether a_shape == b_shape (allows None entries).""" + if len(a_shape) != len(b_shape): + return False + for e1, e2 in zip(a_shape, b_shape): + if e1 is not None and e2 is not None and e1 != e2: + return False + return True + + class Variable(KerasVariable): + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + ): + value = initializer(shape, dtype=dtype) + + # Store in nnx.Param (raw_value will be used as backing store) + self._param = nnx.Param(value) + super().__init__( + initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + ) + + @property + def value(self): + """The current value of the variable (numpy array or backend tensor).""" + if in_stateless_scope(): + scope = get_stateless_scope() + value = scope.get_current_value(self._param.raw_value) + if value is not None: + return self._maybe_autocast(value) + if self._value is None: + # Uninitialized variable. Return a placeholder. + # This is fine because it's only ever used + # in during shape inference / graph tracing + # (anything else would be a bug, to be fixed.) + return self._maybe_autocast( + self._initializer(self._shape, dtype=self._dtype) + ) + return self._maybe_autocast(self._param.raw_value) + + def assign(self, value): + self._param.raw_value = jnp.array(value, dtype=self.dtype) + value = self._convert_to_tensor(value, dtype=self.dtype) + if not shape_equal(value.shape, self.shape): + raise ValueError( + "The shape of the target variable and " + "the shape of the target value in " + "`variable.assign(value)` must match. " + f"variable.shape={self.value.shape}, " + f"Received: value.shape={value.shape}. " + f"Target variable: {self}" + ) + if in_stateless_scope(): + scope = get_stateless_scope() + scope.add_update((self, value)) + else: + self._direct_assign(value) + return value + def _initialize(self, value): # Note that variable.shape is needed by distribution_lib self._shape = self._validate_shape(value.shape) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index fbcc4fe5b5c6..f70470c40cc3 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,2 +1,13 @@ -class JaxLayer: - pass +from flax import nnx + + +class JaxLayer(nnx.Module): + def __new__(cls, *args, **kwargs): + """Overrides __new__ to save constructor arguments for potential + serialization/config. 
+ """ + instance = super(JaxLayer, cls).__new__(cls) + vars(instance)['_object__state'] = nnx.object.ObjectState() + instance.__init_args = args + instance.__init_kwargs = kwargs + return instance From 7159709bee7995c8a27e65c76cfaf73543b71ee6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 7 May 2025 21:46:28 +0000 Subject: [PATCH 002/103] update variables --- keras/src/backend/jax/core.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 7166b732d0f8..f55cc8f99f0b 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -53,7 +53,13 @@ def __init__( value = initializer(shape, dtype=dtype) # Store in nnx.Param (raw_value will be used as backing store) - self._param = nnx.Param(value) + value = initializer(shape, dtype) + if trainable: + self._param = nnx.Param(value) + else: + self._param = nnx.Variable(value) + self.trainable = trainable + self._name = name super().__init__( initializer, shape=shape, @@ -70,7 +76,7 @@ def value(self): """The current value of the variable (numpy array or backend tensor).""" if in_stateless_scope(): scope = get_stateless_scope() - value = scope.get_current_value(self._param.raw_value) + value = scope.get_current_value(self._param.value) if value is not None: return self._maybe_autocast(value) if self._value is None: @@ -84,7 +90,7 @@ def value(self): return self._maybe_autocast(self._param.raw_value) def assign(self, value): - self._param.raw_value = jnp.array(value, dtype=self.dtype) + self._param.value = jnp.array(value, dtype=self.dtype) value = self._convert_to_tensor(value, dtype=self.dtype) if not shape_equal(value.shape, self.shape): raise ValueError( @@ -101,6 +107,13 @@ def assign(self, value): else: self._direct_assign(value) return value + + def numpy(self): + return jax.device_get(self._param.value) + + def __array__(self): + return self._param.value + def _initialize(self, value): # Note that variable.shape is needed by distribution_lib From e378cfbb868853fd12366f80ecaa779243b48c9c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 8 May 2025 23:09:03 +0000 Subject: [PATCH 003/103] add nnx.jit --- examples/demo_jax_distributed.py | 6 ++--- guides/distributed_training_with_jax.py | 4 +-- integration_tests/jax_custom_fit_test.py | 2 ++ keras/src/backend/jax/core.py | 23 ++++++++---------- keras/src/backend/jax/layer.py | 31 ++++++++++++++++-------- keras/src/backend/jax/trainer.py | 11 +++++---- keras/src/random/random_test.py | 7 ++---- keras/src/random/seed_generator_test.py | 5 ++-- 8 files changed, 48 insertions(+), 41 deletions(-) diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index 906dc47563de..9bee7c48f792 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -12,7 +12,7 @@ import jax.numpy as jnp import tensorflow as tf # just for tf.data import keras # Keras multi-backend - +from flax import nnx import numpy as np from tqdm import tqdm @@ -264,7 +264,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): # Training step: Keras provides a pure functional optimizer.stateless_apply -@jax.jit +@nnx.jit def train_step(train_state, x, y): (loss_value, non_trainable_variables), grads = compute_gradients( train_state.trainable_variables, @@ -302,7 +302,7 @@ def train_step(train_state, x, y): sharded_data = jax.device_put(data.numpy(), data_sharding) -@jax.jit +@nnx.jit def predict(data): predictions, 
updated_non_trainable_variables = model.stateless_call( device_train_state.trainable_variables, diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 6f6dbbf25d78..41604a2f3ff0 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -48,7 +48,7 @@ import numpy as np import tensorflow as tf import keras - +from flax import nnx from jax.experimental import mesh_utils from jax.sharding import Mesh from jax.sharding import NamedSharding @@ -186,7 +186,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): # Training step, Keras provides a pure functional optimizer.stateless_apply -@jax.jit +@nnx.jit def train_step(train_state, x, y): ( trainable_variables, diff --git a/integration_tests/jax_custom_fit_test.py b/integration_tests/jax_custom_fit_test.py index 9c9eee59f114..d69b3e35219c 100644 --- a/integration_tests/jax_custom_fit_test.py +++ b/integration_tests/jax_custom_fit_test.py @@ -30,6 +30,7 @@ def compute_loss_and_updates( return loss, (y_pred, non_trainable_variables) def train_step(self, state, data): + print("inside train step with data", data) ( trainable_variables, non_trainable_variables, @@ -91,6 +92,7 @@ def metrics(self): model.compile(optimizer="adam") x = np.random.random((64, 32)) y = np.random.random((64, 1)) + history = model.fit(x, y, epochs=1) assert "loss" in history.history diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index f55cc8f99f0b..e9dbd7665180 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -50,16 +50,6 @@ def __init__( synchronization="auto", name=None, ): - value = initializer(shape, dtype=dtype) - - # Store in nnx.Param (raw_value will be used as backing store) - value = initializer(shape, dtype) - if trainable: - self._param = nnx.Param(value) - else: - self._param = nnx.Variable(value) - self.trainable = trainable - self._name = name super().__init__( initializer, shape=shape, @@ -70,6 +60,14 @@ def __init__( synchronization=synchronization, name=name, ) + value = initializer(shape, dtype=dtype) + + # Store in nnx.Param (raw_value will be used as backing store) + value = initializer(shape, dtype) + if trainable: + self._param = nnx.Param(value) + else: + self._param = nnx.Variable(value) @property def value(self): @@ -107,14 +105,13 @@ def assign(self, value): else: self._direct_assign(value) return value - + def numpy(self): return jax.device_get(self._param.value) - + def __array__(self): return self._param.value - def _initialize(self, value): # Note that variable.shape is needed by distribution_lib self._shape = self._validate_shape(value.shape) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index f70470c40cc3..b56091160631 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,13 +1,24 @@ -from flax import nnx +class JaxLayer: + pass +"""from flax import nnx +import jax.numpy as jnp + class JaxLayer(nnx.Module): - def __new__(cls, *args, **kwargs): - """Overrides __new__ to save constructor arguments for potential - serialization/config. 
- """ - instance = super(JaxLayer, cls).__new__(cls) - vars(instance)['_object__state'] = nnx.object.ObjectState() - instance.__init_args = args - instance.__init_kwargs = kwargs - return instance + def __init__(self): + super().__init__() + + def add_weight(self, name, shape, dtype=None, initializer=None, trainable=True): + value = initializer(shape, dtype) + var = nnx.Param(value) if trainable else nnx.Variable(value) + setattr(self, name, var) + return var + + def get_weights(self): + return [v.value for v in nnx.variables(self, nnx.Param)] + + def set_weights(self, weights): + params = list(nnx.variables(self, nnx.Param).values()) + for var, val in zip(params, weights): + var.value = val """ diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index e7a978ecd6c5..cd72a0f1971d 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -4,6 +4,7 @@ import jax import numpy as np +from flax import nnx from keras.src import backend from keras.src import callbacks as callbacks_module @@ -231,7 +232,7 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - concatenate = jax.jit(concatenate) + concatenate = nnx.jit(concatenate) def iterator_step(state, iterator): data = next(iterator) @@ -275,7 +276,7 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - train_step = jax.jit(self.train_step, donate_argnums=0) + train_step = nnx.jit(self.train_step, donate_argnums=0) else: train_step = self.train_step @@ -291,7 +292,7 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - test_step = jax.jit(self.test_step, donate_argnums=0) + test_step = nnx.jit(self.test_step, donate_argnums=0) else: test_step = self.test_step @@ -308,7 +309,7 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = jax.jit(predict_step, donate_argnums=0) + predict_step = nnx.jit(predict_step, donate_argnums=0) _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -897,7 +898,7 @@ def _enforce_jax_state_sharding( Since the output of the train/eval step will be used as inputs to next step, we need to ensure that they have the same sharding spec, so that - jax.jit won't have to recompile the train/eval function. + nnx.jit won't have to recompile the train/eval function. Note that this function will also rely on the recorded sharding spec for each of states. 
diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py index 9e78b8748b4d..0df730b6da5a 100644 --- a/keras/src/random/random_test.py +++ b/keras/src/random/random_test.py @@ -380,12 +380,11 @@ def test_uniform_dtype_validation(self): reason="This test requires `jax` as the backend.", ) def test_dropout_jax_jit_stateless(self): - import jax import jax.numpy as jnp x = ops.ones(3) - @jax.jit + @nnx.jit def train_step(x): with keras.src.backend.StatelessScope(): x = keras.layers.Dropout(rate=0.1)(x, training=True) @@ -414,9 +413,7 @@ def test_jax_rngkey_seed(self): reason="This test requires `jax` as the backend.", ) def test_jax_unseed_disallowed_during_tracing(self): - import jax - - @jax.jit + @nnx.jit def jit_fn(): return random.randint((2, 2), 0, 10, seed=None) diff --git a/keras/src/random/seed_generator_test.py b/keras/src/random/seed_generator_test.py index d1101e0a871a..a042e165a7c3 100644 --- a/keras/src/random/seed_generator_test.py +++ b/keras/src/random/seed_generator_test.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from flax import nnx from keras.src import backend from keras.src import ops @@ -78,9 +79,7 @@ def test_seed_generator_unexpected_kwargs(self): backend.backend() != "jax", reason="This test requires the JAX backend" ) def test_jax_tracing_with_global_seed_generator(self): - import jax - - @jax.jit + @nnx.jit def traced_function(): return seed_generator.global_seed_generator().next() From 8701fc774ed67363b7308ed73b28d2f40b40c0ea Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 8 May 2025 23:12:56 +0000 Subject: [PATCH 004/103] revert changes to JaxLayer --- keras/src/backend/jax/layer.py | 31 ++++++++++--------------------- keras/src/random/random_test.py | 1 + 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index b56091160631..c8936c5f1c71 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,24 +1,13 @@ -class JaxLayer: - pass +from flax import nnx -"""from flax import nnx -import jax.numpy as jnp - class JaxLayer(nnx.Module): - def __init__(self): - super().__init__() - - def add_weight(self, name, shape, dtype=None, initializer=None, trainable=True): - value = initializer(shape, dtype) - var = nnx.Param(value) if trainable else nnx.Variable(value) - setattr(self, name, var) - return var - - def get_weights(self): - return [v.value for v in nnx.variables(self, nnx.Param)] - - def set_weights(self, weights): - params = list(nnx.variables(self, nnx.Param).values()) - for var, val in zip(params, weights): - var.value = val """ + def __new__(cls, *args, **kwargs): + """Overrides __new__ to save constructor arguments for potential + serialization/config. 
+ """ + instance = super(JaxLayer, cls).__new__(cls) + vars(instance)["_object__state"] = nnx.object.ObjectState() + instance.__init_args = args + instance.__init_kwargs = kwargs + return instance diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py index 0df730b6da5a..d93f6d4557db 100644 --- a/keras/src/random/random_test.py +++ b/keras/src/random/random_test.py @@ -1,6 +1,7 @@ import numpy as np import pytest from absl.testing import parameterized +from flax import nnx import keras from keras.src import backend From 4f7b3b87544b7f24092d51ae1cca13be9b0eac96 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 8 May 2025 23:30:13 +0000 Subject: [PATCH 005/103] format fix --- keras/src/backend/jax/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 335a98c97369..641280908313 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -42,6 +42,7 @@ class Variable(KerasVariable): def __init__( self, initializer, + layout=None, shape=None, dtype=None, trainable=True, @@ -132,6 +133,8 @@ def _initialize(self, value): def _direct_assign(self, value): if self._layout is not None: value = distribution_lib.distribute_variable(value, self._layout) + self._param.value = jnp.array(value, dtype=self.dtype) + value = self._convert_to_tensor(value, dtype=self.dtype) self._value = value def _convert_to_tensor(self, value, dtype=None): From 0234e27c8beb29fa9f775ecb1333c7263a4c8e9f Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 13 May 2025 20:50:54 +0000 Subject: [PATCH 006/103] make variables subclass nnx.Variable --- keras/src/backend/jax/core.py | 48 +++++------------------------------ 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 641280908313..ff98c5d958ab 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -38,47 +38,16 @@ def shape_equal(a_shape, b_shape): return True -class Variable(KerasVariable): - def __init__( - self, - initializer, - layout=None, - shape=None, - dtype=None, - trainable=True, - autocast=True, - aggregation="none", - synchronization="auto", - name=None, - ): - self._layout = layout - super().__init__( - initializer, - shape=shape, - dtype=dtype, - trainable=trainable, - autocast=autocast, - aggregation=aggregation, - synchronization=synchronization, - name=name, - ) - value = initializer(shape, dtype=dtype) - - # Store in nnx.Param (raw_value will be used as backing store) - value = initializer(shape, dtype) - if trainable: - self._param = nnx.Param(value) - else: - self._param = nnx.Variable(value) +class Variable(nnx.Variable, KerasVariable): @property def value(self): """The current value of the variable (numpy array or backend tensor).""" if in_stateless_scope(): scope = get_stateless_scope() - value = scope.get_current_value(self._param.value) + value = scope.get_current_value(self.raw_value) if value is not None: - return self._maybe_autocast(value) + return self._maybe_autocast(self.raw_value) if self._value is None: # Uninitialized variable. Return a placeholder. 
# This is fine because it's only ever used @@ -87,10 +56,10 @@ def value(self): return self._maybe_autocast( self._initializer(self._shape, dtype=self._dtype) ) - return self._maybe_autocast(self._param.raw_value) + return self._maybe_autocast(self.raw_value) def assign(self, value): - self._param.value = jnp.array(value, dtype=self.dtype) + self.raw_value = jnp.array(value, dtype=self.dtype) value = self._convert_to_tensor(value, dtype=self.dtype) if not shape_equal(value.shape, self.shape): raise ValueError( @@ -108,11 +77,6 @@ def assign(self, value): self._direct_assign(value) return value - def numpy(self): - return jax.device_get(self._param.value) - - def __array__(self): - return self._param.value def _initialize(self, value): # Note that variable.shape is needed by distribution_lib @@ -133,7 +97,7 @@ def _initialize(self, value): def _direct_assign(self, value): if self._layout is not None: value = distribution_lib.distribute_variable(value, self._layout) - self._param.value = jnp.array(value, dtype=self.dtype) + self.raw_value = jnp.array(value, dtype=self.dtype) value = self._convert_to_tensor(value, dtype=self.dtype) self._value = value From b87c4f950e594cae3f6684f2a5912743f9ba51ca Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 13 May 2025 23:08:45 +0000 Subject: [PATCH 007/103] more tweaks --- keras/src/backend/jax/core.py | 55 ++++++++++++++++++++++++++++++---- keras/src/backend/jax/layer.py | 3 ++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index ff98c5d958ab..f10d2431efb5 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -39,15 +39,51 @@ def shape_equal(a_shape, b_shape): class Variable(nnx.Variable, KerasVariable): + def __new__(cls, initializer, shape=None, dtype=None, **kwargs): + if dtype is None: + dtype = jnp.float32 + value = initializer(shape, dtype=dtype) + + # Proper construction of nnx.Variable + instance = nnx.Variable.__new__(cls) + nnx.Variable.__init__(instance, raw_value=value) + return instance + + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + ): + # Regular KerasVariable init + KerasVariable.__init__( + self, + initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + ) + + self.trainable = trainable + self._name = name @property def value(self): """The current value of the variable (numpy array or backend tensor).""" if in_stateless_scope(): scope = get_stateless_scope() - value = scope.get_current_value(self.raw_value) + value = scope.get_current_value(self) if value is not None: - return self._maybe_autocast(self.raw_value) + return self._maybe_autocast(self) if self._value is None: # Uninitialized variable. Return a placeholder. # This is fine because it's only ever used @@ -77,6 +113,11 @@ def assign(self, value): self._direct_assign(value) return value + def __jax_array__(self): + return self.value + + def __array__(self, dtype=None): + return np.asarray(self.value, dtype=dtype) def _initialize(self, value): # Note that variable.shape is needed by distribution_lib @@ -104,9 +145,13 @@ def _direct_assign(self, value): def _convert_to_tensor(self, value, dtype=None): return convert_to_tensor(value, dtype=dtype, sparse=False) - # Overload native accessor. 
- def __jax_array__(self): - return self.value + @property + def dtype(self): + return self.raw_value.dtype + + @property + def shape(self): + return self.raw_value.shape def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index c8936c5f1c71..66d735188b11 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -2,6 +2,9 @@ class JaxLayer(nnx.Module): + def __init_subclass__(cls): + super().__init_subclass__() + def __new__(cls, *args, **kwargs): """Overrides __new__ to save constructor arguments for potential serialization/config. From 91b9a731b36d08ab0b35ef3e1736ce23086627a0 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 27 May 2025 22:41:02 +0000 Subject: [PATCH 008/103] update init --- keras/src/backend/jax/core.py | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index f10d2431efb5..3254fbd7e409 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -49,32 +49,13 @@ def __new__(cls, initializer, shape=None, dtype=None, **kwargs): nnx.Variable.__init__(instance, raw_value=value) return instance - def __init__( - self, - initializer, - shape=None, - dtype=None, - trainable=True, - autocast=True, - aggregation="none", - synchronization="auto", - name=None, - ): - # Regular KerasVariable init - KerasVariable.__init__( - self, - initializer, - shape=shape, - dtype=dtype, - trainable=trainable, - autocast=autocast, - aggregation=aggregation, - synchronization=synchronization, - name=name, - ) + def __init__(self, *args, layout=None, **kwargs): + # Intercept layout parameter so that it is available + # during initialization. + nnx.Variable.__init__(self, *args, **kwargs) + KerasVariable.__init__(self, *args, **kwargs) - self.trainable = trainable - self._name = name + self._layout = layout @property def value(self): From aee17893f5b799412d7b10057bc0aacecb4e2f1c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 28 May 2025 00:10:12 +0000 Subject: [PATCH 009/103] refactor jax Variable class --- keras/src/backend/jax/core.py | 305 ++++++++++++++++++++++++++-------- 1 file changed, 235 insertions(+), 70 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 3254fbd7e409..9cbc2cc87e8a 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -38,73 +38,14 @@ def shape_equal(a_shape, b_shape): return True -class Variable(nnx.Variable, KerasVariable): - def __new__(cls, initializer, shape=None, dtype=None, **kwargs): - if dtype is None: - dtype = jnp.float32 - value = initializer(shape, dtype=dtype) - - # Proper construction of nnx.Variable - instance = nnx.Variable.__new__(cls) - nnx.Variable.__init__(instance, raw_value=value) - return instance - +# existing implementation +class JaxVariable(KerasVariable): def __init__(self, *args, layout=None, **kwargs): - # Intercept layout parameter so that it is available - # during initialization. 
- nnx.Variable.__init__(self, *args, **kwargs) - KerasVariable.__init__(self, *args, **kwargs) - - self._layout = layout - - @property - def value(self): - """The current value of the variable (numpy array or backend tensor).""" - if in_stateless_scope(): - scope = get_stateless_scope() - value = scope.get_current_value(self) - if value is not None: - return self._maybe_autocast(self) - if self._value is None: - # Uninitialized variable. Return a placeholder. - # This is fine because it's only ever used - # in during shape inference / graph tracing - # (anything else would be a bug, to be fixed.) - return self._maybe_autocast( - self._initializer(self._shape, dtype=self._dtype) - ) - return self._maybe_autocast(self.raw_value) - - def assign(self, value): - self.raw_value = jnp.array(value, dtype=self.dtype) - value = self._convert_to_tensor(value, dtype=self.dtype) - if not shape_equal(value.shape, self.shape): - raise ValueError( - "The shape of the target variable and " - "the shape of the target value in " - "`variable.assign(value)` must match. " - f"variable.shape={self.value.shape}, " - f"Received: value.shape={value.shape}. " - f"Target variable: {self}" - ) - if in_stateless_scope(): - scope = get_stateless_scope() - scope.add_update((self, value)) - else: - self._direct_assign(value) - return value - - def __jax_array__(self): - return self.value - - def __array__(self, dtype=None): - return np.asarray(self.value, dtype=dtype) + self._layout = layout + super().__init__(*args, **kwargs) def _initialize(self, value): - # Note that variable.shape is needed by distribution_lib self._shape = self._validate_shape(value.shape) - # We can't import the keras/distribution/distribution_lib - # due to circular dependency. distribution = global_state.get_global_attribute("distribution") if self._layout is None and distribution is not None: tensor_layout = distribution.get_variable_layout(self) @@ -119,20 +60,244 @@ def _initialize(self, value): def _direct_assign(self, value): if self._layout is not None: value = distribution_lib.distribute_variable(value, self._layout) - self.raw_value = jnp.array(value, dtype=self.dtype) - value = self._convert_to_tensor(value, dtype=self.dtype) self._value = value def _convert_to_tensor(self, value, dtype=None): return convert_to_tensor(value, dtype=dtype, sparse=False) - @property - def dtype(self): - return self.raw_value.dtype + def __jax_array__(self): + return self.value + + +class Variable(JaxVariable, nnx.Variable): + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + # Keras specific args from KerasVariable + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + # Keras JAX backend specific + layout=None, + # NNX specific args + nnx_mutable=None, # NNX's own mutable flag + *args, + **nnx_metadata, # For nnx.Variable's **metadata + ): + # We need to call KerasJaxVariableImpl.__init__ first + # KerasJaxVariableImpl's __init__ takes `layout` specifically + # and forwards other Keras common args to CommonKerasVariable.__init__ + super( + JaxVariable, self + ).__init__( # Explicitly call KerasJaxVariableImpl's __init__ + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + layout=layout, + *args, + ) + + # Store NNX args for potential deferred initialization + self._nnx_mutable_arg = nnx_mutable + self._nnx_metadata_arg = nnx_metadata.copy() + self._nnx_init_pending = True + 
+ # If Keras initialization was not deferred, self._value is now set. + # So we can proceed to initialize the nnx.Variable part. + if self._initializer is None: + self._complete_nnx_init() + + def _complete_nnx_init(self): + """Initializes the nnx.Variable part of this instance.""" + if not self._nnx_init_pending: + return # Already done + + if self._value is None: + # This can happen if _deferred_initialize was called but _value somehow didn't get set + # Or if this is called too early. Keras's _initializer should not be None here. + raise ValueError( + "Cannot initialize NNX part: Keras self._value is None, " + "but Keras initializer is also None (should not be deferred)." + ) + + # Determine nnx_mutable for nnx.Variable.__init__ + # If user didn't specify nnx_mutable, default to Keras's trainable status. + current_nnx_mutable = self._nnx_mutable_arg + if current_nnx_mutable is None: + current_nnx_mutable = self.trainable # A sensible default link + + # initialize the nnx.Variable + nnx.Variable.__init__( + self, + value=self._value, + mutable=current_nnx_mutable, + **self._nnx_metadata_arg, + ) + self._nnx_init_pending = False + + def _deferred_initialize(self): + # This is called by Keras when it's time to actually create the variable's value + super()._deferred_initialize() + self._complete_nnx_init() + + def _direct_assign(self, value): + super()._direct_assign(value) # This sets self._value + + # After self._value is updated by Keras, sync nnx.Variable.raw_value + # Only if NNX part is already initialized. + if not self._nnx_init_pending: + nnx_stores_mutable = False + if ( + self._nnx_mutable_arg is None + ): # Check how nnx_mutable was resolved + nnx_stores_mutable = self.trainable + else: + nnx_stores_mutable = self._nnx_mutable_arg + + if nnx_stores_mutable and nnx.utils.is_mutable_array( + self.raw_value + ): + # If raw_value is a mutable_array, update its content + self.raw_value[...] = self._value + else: + object.__setattr__(self, "raw_value", self._value) @property - def shape(self): - return self.raw_value.shape + def value(self): + # This will be KerasVariable.value: + return super().value + + @value.setter + def value(self, new_value): + self.assign( + new_value + ) # assign will call _direct_assign, which syncs raw_value + + # Overriding NNX methods that modify `raw_value` or `_var_metadata` directly + # to ensure Keras's `_value` and other Keras states are in sync. + + def copy_from(self, other: nnx.Variable): # type: ignore + if not isinstance(other, nnx.Variable): # Basic check from nnx + raise TypeError( + f"Expected nnx.Variable, got {type(other).__name__}" + ) + if not isinstance(other, Variable): + pass + + # Let nnx.Variable handle its part (updates self.raw_value and self._var_metadata) + # Need to call nnx.Variable.copy_from specifically. + nnx.Variable.copy_from(self, other) + + # Now, self.raw_value is updated. Sync Keras's self._value. 
+ # Extract the JAX array if raw_value is a nnx.mutable_array + keras_value_to_assign = self.raw_value + if nnx.utils.is_mutable_array(keras_value_to_assign): + keras_value_to_assign = keras_value_to_assign.__array__() + + self.assign(keras_value_to_assign) + + # Sync Keras-specific attributes if `other` is also a JaxNnxVariable + if isinstance(other, Variable): + self.trainable = other.trainable + self._autocast = other._autocast + self._aggregation = other._aggregation + if hasattr(other, "_layout"): + self._layout = other._layout + + def update_from_state(self, variable_state: nnx.graph.VariableState): + # Let nnx.Variable handle its part (updates self.raw_value and self._var_metadata) + nnx.Variable.update_from_state(self, variable_state) + + # Sync Keras's self._value + keras_value_to_assign = self.raw_value + if nnx.utils.is_mutable_array(keras_value_to_assign): + keras_value_to_assign = keras_value_to_assign.__array__() + + self.assign(keras_value_to_assign) + + # Sync Keras attributes if they were part of variable_state.metadata + if "trainable" in variable_state._var_metadata: # type: ignore + self.trainable = variable_state._var_metadata["trainable"] + self._autocast = variable_state._var_metadata["autocast"] + + def __getstate__(self): + keras_state = { + # Keras common attributes (from CommonKerasVariable) + "_name": self._name, + "_path": self._path, + "_trainable": self._trainable, + "_dtype": self._dtype, + "_shape": self._shape, + "_autocast": self._autocast, + "_aggregation": self._aggregation, + "_synchronization": self._synchronization, + "_regularizer": self._regularizer, + "_constraint": self._constraint, + # Keras JAX backend specific + "_layout": self._layout, + # Value itself (will be part of nnx_state's raw_value too) + "_value": self._value, # Keras's value (JAX array) + "_initializer": self._initializer, # In case it's not initialized + # NNX specific args that were stored at init + "_nnx_mutable_arg": self._nnx_mutable_arg, + "_nnx_metadata_arg": self._nnx_metadata_arg, + "_nnx_init_pending": self._nnx_init_pending, + } + nnx_state = nnx.Variable.__getstate__(self) + return {"keras_state": keras_state, "nnx_state": nnx_state} + + def __setstate__(self, state): + keras_state = state["keras_state"] + nnx_state = state["nnx_state"] + + # Restore Keras attributes + for k, v in keras_state.items(): + object.__setattr__(self, k, v) + + # Restore NNX attributes using its __setstate__ + nnx.Variable.__setstate__(self, nnx_state) + + if ( + self._initializer is not None and self._value is None + ): # Was deferred pre-pickle + if ( + not self._nnx_init_pending + and hasattr(self, "raw_value") + and self.raw_value is not None + ): + pass # self._value is already set from keras_state. + + # If self._value exists (from Keras state), ensure nnx.raw_value matches + if self._value is not None: + if self._nnx_init_pending: + self._complete_nnx_init() + else: + # This is similar to _direct_assign's sync logic. + current_nnx_mutable = self._nnx_mutable_arg + if current_nnx_mutable is None: + current_nnx_mutable = self.trainable + + if current_nnx_mutable and nnx.utils.is_mutable_array( + self.raw_value + ): + self.raw_value[...] 
= self._value + else: + object.__setattr__(self, "raw_value", self._value) + elif ( + not self._nnx_init_pending + and hasattr(self, "raw_value") + and self.raw_value is not None + ): + object.__setattr__(self, "_value", self.raw_value) def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): From 141487f53d1b56e5d362bba2655a9188a791157a Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 28 May 2025 00:23:13 +0000 Subject: [PATCH 010/103] code reformat --- keras/src/backend/jax/core.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 9cbc2cc87e8a..1a5f3561b746 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -122,15 +122,14 @@ def _complete_nnx_init(self): return # Already done if self._value is None: - # This can happen if _deferred_initialize was called but _value somehow didn't get set - # Or if this is called too early. Keras's _initializer should not be None here. raise ValueError( "Cannot initialize NNX part: Keras self._value is None, " "but Keras initializer is also None (should not be deferred)." ) # Determine nnx_mutable for nnx.Variable.__init__ - # If user didn't specify nnx_mutable, default to Keras's trainable status. + # If user didn't specify nnx_mutable, default to Keras's trainable + # status. current_nnx_mutable = self._nnx_mutable_arg if current_nnx_mutable is None: current_nnx_mutable = self.trainable # A sensible default link @@ -145,7 +144,8 @@ def _complete_nnx_init(self): self._nnx_init_pending = False def _deferred_initialize(self): - # This is called by Keras when it's time to actually create the variable's value + # This is called by Keras when it's time to actually create the + # variable's value super()._deferred_initialize() self._complete_nnx_init() @@ -193,7 +193,8 @@ def copy_from(self, other: nnx.Variable): # type: ignore if not isinstance(other, Variable): pass - # Let nnx.Variable handle its part (updates self.raw_value and self._var_metadata) + # Let nnx.Variable handle its part (updates self.raw_value and + # self._var_metadata) # Need to call nnx.Variable.copy_from specifically. 
nnx.Variable.copy_from(self, other) @@ -214,7 +215,8 @@ def copy_from(self, other: nnx.Variable): # type: ignore self._layout = other._layout def update_from_state(self, variable_state: nnx.graph.VariableState): - # Let nnx.Variable handle its part (updates self.raw_value and self._var_metadata) + # Let nnx.Variable handle its part (updates self.raw_value and + # self._var_metadata) nnx.Variable.update_from_state(self, variable_state) # Sync Keras's self._value From 5ccc31e4ccf3c4675665883335c74dc4d5953a93 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 28 May 2025 00:27:13 +0000 Subject: [PATCH 011/103] more cleanup --- keras/src/backend/jax/core.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 1a5f3561b746..e16c775a12b1 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -20,32 +20,19 @@ IS_THREAD_SAFE = True -def in_stateless_scope(): - return global_state.get_global_attribute("stateless_scope") is not None - - -def get_stateless_scope(): - return global_state.get_global_attribute("stateless_scope") - - -def shape_equal(a_shape, b_shape): - """Return whether a_shape == b_shape (allows None entries).""" - if len(a_shape) != len(b_shape): - return False - for e1, e2 in zip(a_shape, b_shape): - if e1 is not None and e2 is not None and e1 != e2: - return False - return True - - # existing implementation class JaxVariable(KerasVariable): def __init__(self, *args, layout=None, **kwargs): + # Intercept layout parameter so that it is available + # during initialization. self._layout = layout super().__init__(*args, **kwargs) def _initialize(self, value): + # Note that variable.shape is needed by distribution_lib self._shape = self._validate_shape(value.shape) + # We can't import the keras/distribution/distribution_lib + # due to circular dependency. distribution = global_state.get_global_attribute("distribution") if self._layout is None and distribution is not None: tensor_layout = distribution.get_variable_layout(self) @@ -65,6 +52,7 @@ def _direct_assign(self, value): def _convert_to_tensor(self, value, dtype=None): return convert_to_tensor(value, dtype=dtype, sparse=False) + # Overload native accessor. def __jax_array__(self): return self.value From dd9c77d75c53bcb6c5879eef1788b9d21d44d6c6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 28 May 2025 19:57:53 +0000 Subject: [PATCH 012/103] update flax version --- requirements-jax-cuda.txt | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 7fd5763924b5..55d754f3570f 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,6 +9,6 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax +flax>=0.10.6 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index 8d150a4e989e..fbcdd34f52ac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax +flax>=0.10.6 # Common deps. 
-r requirements-common.txt From 48983c6195d5f9b699099f15c71b2a3199ddb4a1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 28 May 2025 20:00:17 +0000 Subject: [PATCH 013/103] update flax version --- requirements-jax-cuda.txt | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 55d754f3570f..d2358c0c0cde 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,6 +9,6 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax>=0.10.6 +flax>=0.10.1 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index fbcdd34f52ac..c76360c7a501 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax>=0.10.6 +flax>=0.10.1 # Common deps. -r requirements-common.txt From b22f9ef5d7addb876ff62794a4bd3104fa4f74d9 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 29 May 2025 21:33:08 +0000 Subject: [PATCH 014/103] fix jax error --- integration_tests/import_test.py | 2 +- requirements-common.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index e7af37f23c83..e2cd5484ca68 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -11,7 +11,7 @@ "torch", "--extra-index-url https://download.pytorch.org/whl/cpu ", ), - "jax": ("jax[cpu]", ""), + "jax": ("jax[cpu]==0.5.0", ""), "openvino": ("openvino", ""), } diff --git a/requirements-common.txt b/requirements-common.txt index 7edc40c97a1a..21ec0efe7cdd 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,3 +24,4 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino +flax>=0.10.1 From 4dbffa6e634f6af25e240ce691d19f99dfd55c32 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 29 May 2025 22:08:37 +0000 Subject: [PATCH 015/103] update Variables implementation --- keras/src/backend/jax/core.py | 382 ++++++++++++++-------------------- 1 file changed, 161 insertions(+), 221 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index e16c775a12b1..2d912ada92d2 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import ml_dtypes import numpy as np -from flax import nnx +from flax.experimental import nnx from keras.src import tree from keras.src.backend.common import KerasVariable @@ -12,6 +12,8 @@ from keras.src.backend.common.keras_tensor import KerasTensor from keras.src.backend.common.name_scope import name_scope as base_name_scope from keras.src.backend.common.stateless_scope import StatelessScope +from keras.src.backend.common.stateless_scope import get_stateless_scope +from keras.src.backend.common.stateless_scope import in_stateless_scope from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib @@ -20,274 +22,212 @@ IS_THREAD_SAFE = True -# existing implementation -class JaxVariable(KerasVariable): - def __init__(self, *args, layout=None, **kwargs): - # Intercept layout parameter so that it is available - # during initialization. 
- self._layout = layout - super().__init__(*args, **kwargs) - - def _initialize(self, value): - # Note that variable.shape is needed by distribution_lib - self._shape = self._validate_shape(value.shape) - # We can't import the keras/distribution/distribution_lib - # due to circular dependency. - distribution = global_state.get_global_attribute("distribution") - if self._layout is None and distribution is not None: - tensor_layout = distribution.get_variable_layout(self) - from keras.src.distribution import TensorLayout - - if isinstance(tensor_layout, TensorLayout): - self._layout = tensor_layout.backend_layout - else: - self._layout = tensor_layout - self._direct_assign(value) - - def _direct_assign(self, value): - if self._layout is not None: - value = distribution_lib.distribute_variable(value, self._layout) - self._value = value - - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype, sparse=False) - - # Overload native accessor. - def __jax_array__(self): - return self.value - - -class Variable(JaxVariable, nnx.Variable): +class Variable(KerasVariable, nnx.Variable): def __init__( self, initializer, shape=None, dtype=None, trainable=True, - # Keras specific args from KerasVariable autocast=True, aggregation="none", synchronization="auto", name=None, - # Keras JAX backend specific layout=None, - # NNX specific args - nnx_mutable=None, # NNX's own mutable flag - *args, - **nnx_metadata, # For nnx.Variable's **metadata + mutable=None, + **nnx_metadata, ): - # We need to call KerasJaxVariableImpl.__init__ first - # KerasJaxVariableImpl's __init__ takes `layout` specifically - # and forwards other Keras common args to CommonKerasVariable.__init__ - super( - JaxVariable, self - ).__init__( # Explicitly call KerasJaxVariableImpl's __init__ + # Determine NNX mutability. This needs to be known for nnx.Variable.__init__. + if mutable is None: + actual_nnx_mutable = ( + trainable # Keras 'trainable' maps to NNX 'mutable' + ) + else: + actual_nnx_mutable = mutable + + # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' param takes precedence. + if "mutable" in nnx_metadata and mutable is not None: + nnx_metadata["mutable"] = actual_nnx_mutable + elif "mutable" not in nnx_metadata: + nnx_metadata["mutable"] = actual_nnx_mutable + + # Initialize nnx.Variable first. + if shape is not None and dtype is not None: + # If initializer is a Keras callable, it's not ready yet. + # If initializer is already a value, KerasVariable will handle it. + # We need a concrete array for the placeholder. + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(dtype) + ) + elif shape is not None: + _placeholder_value = jnp.zeros(shape, dtype=jnp.float32) + else: + _placeholder_value = jnp.array(0.0, dtype=jnp.float32) + + # Call nnx.Variable.__init__ directly. + nnx.Variable.__init__(self, value=_placeholder_value, **nnx_metadata) + + # Store JAX-specific layout using object.__setattr__ BEFORE KerasVariable init. + # This is because KerasVariable.__init__ will call self._initialize, which uses self._layout. + object.__setattr__(self, "_layout", layout) + + # Initialize KerasVariable. 
        super(Variable, self).__init__(
            initializer=initializer,
            shape=shape,
            dtype=dtype,
            trainable=trainable,  # Pass Keras trainable
            autocast=autocast,
            aggregation=aggregation,
            synchronization=synchronization,
            name=name,
        )

        # self._value now holds the true JAX array from KerasVariable init.
        # Update nnx.Variable's internal value (raw_value) to match.
        object.__setattr__(self, "raw_value", self._value)

    def __getstate__(self):
        # Get the state from KerasVariable (attributes in __dict__).
        # KerasVariable does not define a custom __getstate__, so we mimic
        # the default behavior.
        keras_state = self.__dict__.copy()

        # Get the state from nnx.Variable.
        nnx_specific_state = super(KerasVariable, self).__getstate__()

        # Merge them. Keras state is primary; NNX-specific state adds to it.
        if "raw_value" in nnx_specific_state:
            keras_state["_value"] = nnx_specific_state["raw_value"]

        # Add NNX attributes that are not in Keras's __dict__.
        if "_trace_state" in nnx_specific_state:
            keras_state["_trace_state"] = nnx_specific_state["_trace_state"]
        if "_var_metadata" in nnx_specific_state:
            keras_state["_var_metadata"] = nnx_specific_state["_var_metadata"]

        # Remove entries that would be redundant with what
        # nnx.Variable.__getstate__ already captured above.
        keras_state.pop("raw_value", None)

        return keras_state

    def __setstate__(self, state):
        # Separate out the NNX-specific keys that our __getstate__ merged
        # into the main state dictionary.
        nnx_raw_value = state["_value"]  # This was raw_value.
        nnx_trace_state = state.pop("_trace_state", None)
        nnx_var_metadata = state.pop("_var_metadata", None)

        # Populate the instance's __dict__ with the Keras attributes.
        self.__dict__.update(state)

        # Restore the nnx.Variable-specific slotted attributes.
        object.__setattr__(self, "raw_value", nnx_raw_value)

        if nnx_trace_state is not None:
            object.__setattr__(self, "_trace_state", nnx_trace_state)

        if nnx_var_metadata is not None:
            object.__setattr__(self, "_var_metadata", nnx_var_metadata)

        # Ensure Keras's self._value is consistent with the restored
        # raw_value.
        object.__setattr__(self, "_value", nnx_raw_value)

        if hasattr(self, "_shape") and self._shape is not None:
            self._ndim = len(self._shape)
        else:
            # Fall back to the restored value if the shape isn't
            # immediately available.
            self._ndim = len(self.raw_value.shape)

    def _initialize(self, value):
        # Note that variable.shape is needed by distribution_lib.
        self._shape = self._validate_shape(value.shape)
        # We can't import keras/distribution/distribution_lib here due to a
        # circular dependency.
+ distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + tensor_layout = distribution.get_variable_layout(self) + from keras.src.distribution import TensorLayout - def update_from_state(self, variable_state: nnx.graph.VariableState): - # Let nnx.Variable handle its part (updates self.raw_value and - # self._var_metadata) - nnx.Variable.update_from_state(self, variable_state) + if isinstance(tensor_layout, TensorLayout): + self._layout = tensor_layout.backend_layout + else: + self._layout = tensor_layout + self._direct_assign(value) - # Sync Keras's self._value - keras_value_to_assign = self.raw_value - if nnx.utils.is_mutable_array(keras_value_to_assign): - keras_value_to_assign = keras_value_to_assign.__array__() + def _direct_assign(self, value): + # Apply JAX-specific distribution if layout is present + if self._layout is not None: + processed_value = distribution_lib.distribute_variable( + value, self._layout + ) + else: + processed_value = value - self.assign(keras_value_to_assign) + # Update Keras's internal _value + self._value = processed_value - # Sync Keras attributes if they were part of variable_state.metadata - if "trainable" in variable_state._var_metadata: # type: ignore - self.trainable = variable_state._var_metadata["trainable"] - self._autocast = variable_state._var_metadata["autocast"] + # Ensure that nnx.Variable part is initialized + if not hasattr(self, "_var_metadata"): + # todo: should add a warning + pass - def __getstate__(self): - keras_state = { - # Keras common attributes (from CommonKerasVariable) - "_name": self._name, - "_path": self._path, - "_trainable": self._trainable, - "_dtype": self._dtype, - "_shape": self._shape, - "_autocast": self._autocast, - "_aggregation": self._aggregation, - "_synchronization": self._synchronization, - "_regularizer": self._regularizer, - "_constraint": self._constraint, - # Keras JAX backend specific - "_layout": self._layout, - # Value itself (will be part of nnx_state's raw_value too) - "_value": self._value, # Keras's value (JAX array) - "_initializer": self._initializer, # In case it's not initialized - # NNX specific args that were stored at init - "_nnx_mutable_arg": self._nnx_mutable_arg, - "_nnx_metadata_arg": self._nnx_metadata_arg, - "_nnx_init_pending": self._nnx_init_pending, - } - nnx_state = nnx.Variable.__getstate__(self) - return {"keras_state": keras_state, "nnx_state": nnx_state} + # Apply on_set_value hook if it exists + if ( + hasattr(self, "_var_metadata") + and "on_set_value" in self._var_metadata + ): + final_value = self._var_metadata["on_set_value"]( + self, processed_value + ) + else: + final_value = processed_value - def __setstate__(self, state): - keras_state = state["keras_state"] - nnx_state = state["nnx_state"] + # Directly set raw_value. nnx.Variable handles mutable array updates + object.__setattr__(self, "raw_value", final_value) - # Restore Keras attributes - for k, v in keras_state.items(): - object.__setattr__(self, k, v) + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype, sparse=False) - # Restore NNX attributes using its __setstate__ - nnx.Variable.__setstate__(self, nnx_state) + # Overload native accessor. 
+ def __jax_array__(self): + return self.value - if ( - self._initializer is not None and self._value is None - ): # Was deferred pre-pickle + @property + def value(self): + if not hasattr(self, "raw_value"): + if not hasattr(self, "_value") or self._value is None: + if self._initializer is not None: + initial_value = self._initializer( + self._shape, dtype=self._dtype + ) + return self._maybe_autocast(initial_value) + else: + raise AttributeError( + "Variable is not properly initialized and has no initializer." + ) + current_value = self._value + else: + current_value = self.raw_value + # NNX specific: if raw_value is a mutable_array wrapper, get the actual array. if ( - not self._nnx_init_pending - and hasattr(self, "raw_value") - and self.raw_value is not None + hasattr(self, "_var_metadata") + and "on_get_value" in self._var_metadata ): - pass # self._value is already set from keras_state. + current_value = self._var_metadata["on_get_value"]( + self, current_value + ) - # If self._value exists (from Keras state), ensure nnx.raw_value matches - if self._value is not None: - if self._nnx_init_pending: - self._complete_nnx_init() - else: - # This is similar to _direct_assign's sync logic. - current_nnx_mutable = self._nnx_mutable_arg - if current_nnx_mutable is None: - current_nnx_mutable = self.trainable + if in_stateless_scope(): + scope = get_stateless_scope() + stateless_value = scope.get_current_value(self) + if stateless_value is not None: + return self._maybe_autocast(stateless_value) - if current_nnx_mutable and nnx.utils.is_mutable_array( - self.raw_value - ): - self.raw_value[...] = self._value - else: - object.__setattr__(self, "raw_value", self._value) - elif ( - not self._nnx_init_pending - and hasattr(self, "raw_value") - and self.raw_value is not None - ): - object.__setattr__(self, "_value", self.raw_value) + return self._maybe_autocast(current_value) def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): From 627e58146cf3f2f3dc13b36e185f20fc85ed2fbb Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 29 May 2025 22:21:27 +0000 Subject: [PATCH 016/103] fix import --- keras/src/backend/jax/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 2d912ada92d2..e9e510c2c6e8 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import ml_dtypes import numpy as np -from flax.experimental import nnx +from flax import nnx from keras.src import tree from keras.src.backend.common import KerasVariable From f58ef60a4fe1c31cf169633966f253b025a7d5e5 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 29 May 2025 22:39:09 +0000 Subject: [PATCH 017/103] add a test --- keras/src/backend/jax/core_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 keras/src/backend/jax/core_test.py diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py new file mode 100644 index 000000000000..1ecad0112638 --- /dev/null +++ b/keras/src/backend/jax/core_test.py @@ -0,0 +1,28 @@ + +import jax.numpy as jnp +import pytest +from flax import nnx + +from keras.src import backend +from keras.src.backend.jax.core import Variable as KerasJaxVariable +from keras.src import testing + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JAX backend specific test for core Variable integration with NNX.", +) +class JaxCoreVariableTest(testing.TestCase): + def 
test_variable_in_nnx_module(self): + class Model(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.custom_variable = KerasJaxVariable(jnp.ones((1, 3))) + def __call__(self, x): + return self.linear(x) + self.custom_variable + + model = Model(rngs=nnx.Rngs(0)) + self.assertTrue(hasattr(model.custom_variable,"_trace_state")) + self.assertIsNotNone(model.custom_variable._trace_state) + self.assertAllEqual(model.custom_variable.value, [[1, 1, 1]]) + self.assertTrue(isinstance(model.custom_variable, nnx.Variable)) From c2b73b73a1eb52aa53a9526f856a00605c1dc63e Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 29 May 2025 23:12:38 +0000 Subject: [PATCH 018/103] needs updates in operation --- keras/src/ops/operation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index cd3123be3b33..1ffdc3bd5c51 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -14,6 +14,8 @@ @keras_export("keras.Operation") class Operation: + def __init_subclass__(cls): + super().__init_subclass__() def __init__(self, dtype=None, name=None): if name is None: name = auto_name(self.__class__.__name__) @@ -119,7 +121,9 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. """ instance = super(Operation, cls).__new__(cls) - + if backend.backend()=="jax": + from flax import nnx + vars(instance)['_object__state'] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`. arg_names = inspect.getfullargspec(cls.__init__).args kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) From a662f5ec0c524aed919944f66290e28203a09700 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 30 May 2025 16:58:31 +0000 Subject: [PATCH 019/103] remove __new__ from JaxLayer --- keras/src/backend/jax/layer.py | 10 ---------- keras/src/ops/operation.py | 8 +++++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index 66d735188b11..70c486d3e178 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -4,13 +4,3 @@ class JaxLayer(nnx.Module): def __init_subclass__(cls): super().__init_subclass__() - - def __new__(cls, *args, **kwargs): - """Overrides __new__ to save constructor arguments for potential - serialization/config. - """ - instance = super(JaxLayer, cls).__new__(cls) - vars(instance)["_object__state"] = nnx.object.ObjectState() - instance.__init_args = args - instance.__init_kwargs = kwargs - return instance diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index 1ffdc3bd5c51..de4786a7c992 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -16,6 +16,7 @@ class Operation: def __init_subclass__(cls): super().__init_subclass__() + def __init__(self, dtype=None, name=None): if name is None: name = auto_name(self.__class__.__name__) @@ -121,9 +122,10 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. """ instance = super(Operation, cls).__new__(cls) - if backend.backend()=="jax": - from flax import nnx - vars(instance)['_object__state'] = nnx.object.ObjectState() + if backend.backend() == "jax": + from flax import nnx + + vars(instance)["_object__state"] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`. 
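        # (inspect.getfullargspec reads the subclass __init__ signature, so
        # the positional args can be zipped back to parameter names below.)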
arg_names = inspect.getfullargspec(cls.__init__).args kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) From 396f973a2efb90b2e40e960a72f74b51ceae660e Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 30 May 2025 19:03:48 +0000 Subject: [PATCH 020/103] update base optimizers --- keras/src/optimizers/base_optimizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index a996e9945cc8..a4dcbeab0d40 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -774,6 +774,8 @@ def _get_current_learning_rate(self): self._learning_rate, learning_rate_schedule.LearningRateSchedule ): return self._learning_rate(self._iterations) + elif isinstance(self._learning_rate, backend.Variable): + return self._learning_rate elif callable(self._learning_rate): return self._learning_rate() return self._learning_rate From 30e971d4c884f6c7c95dfe504cf095ec2bba1745 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 30 May 2025 20:08:27 +0000 Subject: [PATCH 021/103] code reformat+ model saving tests --- keras/src/backend/jax/core.py | 30 ++++++++++++------- keras/src/backend/jax/core_test.py | 48 ++++++++++++++++++++++-------- 2 files changed, 55 insertions(+), 23 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index e9e510c2c6e8..ae7a516f2fe1 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -37,7 +37,8 @@ def __init__( mutable=None, **nnx_metadata, ): - # Determine NNX mutability. This needs to be known for nnx.Variable.__init__. + # Determine NNX mutability. This needs to be known for + # nnx.Variable.__init__. if mutable is None: actual_nnx_mutable = ( trainable # Keras 'trainable' maps to NNX 'mutable' @@ -45,7 +46,8 @@ def __init__( else: actual_nnx_mutable = mutable - # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' param takes precedence. + # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' param + # takes precedence. if "mutable" in nnx_metadata and mutable is not None: nnx_metadata["mutable"] = actual_nnx_mutable elif "mutable" not in nnx_metadata: @@ -67,8 +69,10 @@ def __init__( # Call nnx.Variable.__init__ directly. nnx.Variable.__init__(self, value=_placeholder_value, **nnx_metadata) - # Store JAX-specific layout using object.__setattr__ BEFORE KerasVariable init. - # This is because KerasVariable.__init__ will call self._initialize, which uses self._layout. + # Store JAX-specific layout using object.__setattr__ BEFORE + # KerasVariable init. + # This is because KerasVariable.__init__ will call self._initialize, + # which uses self._layout. object.__setattr__(self, "_layout", layout) # Initialize KerasVariable. @@ -89,7 +93,8 @@ def __init__( def __getstate__(self): # Get the state from KerasVariable (attributes in __dict__) - # KerasVariable does not have a custom __getstate__, so we mimic default behavior. + # KerasVariable does not have a custom __getstate__, so we mimic + # default behavior. 
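+        # (The pickled state must carry both halves of this hybrid object:
+        # the Keras attributes in `__dict__` and the slotted nnx.Variable
+        # fields that get merged in below.)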
keras_state = self.__dict__.copy() # Get the state from nnx.Variable @@ -105,13 +110,15 @@ def __getstate__(self): if "_var_metadata" in nnx_specific_state: keras_state["_var_metadata"] = nnx_specific_state["_var_metadata"] - # Remove elements that might be problematic or redundant if nnx.Variable's __getstate__ + # Remove elements that might be problematic or redundant if + # nnx.Variable's __getstate__ keras_state.pop("raw_value", None) return keras_state def __setstate__(self, state): - # Separate nnx specific keys that we added if they are not part of Keras __dict__ + # Separate nnx specific keys that we added if they are not part of + # Keras __dict__ # Our __getstate__ puts them into the main state dictionary. nnx_raw_value = state["_value"] # This was raw_value nnx_trace_state = state.pop("_trace_state", None) @@ -133,7 +140,8 @@ def __setstate__(self, state): else: pass - # Ensure Keras's self._value is also consistent with the restored raw_value + # Ensure Keras's self._value is also consistent with the restored + # raw_value object.__setattr__(self, "_value", nnx_raw_value) if hasattr(self, "_shape") and self._shape is not None: @@ -207,12 +215,14 @@ def value(self): return self._maybe_autocast(initial_value) else: raise AttributeError( - "Variable is not properly initialized and has no initializer." + "Variable is not properly initialized and has no " + "initializer." ) current_value = self._value else: current_value = self.raw_value - # NNX specific: if raw_value is a mutable_array wrapper, get the actual array. + # NNX specific: if raw_value is a mutable_array wrapper, get the + # actual array. if ( hasattr(self, "_var_metadata") and "on_get_value" in self._var_metadata diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index 1ecad0112638..06d6670cc02c 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -1,11 +1,14 @@ +import os import jax.numpy as jnp +import numpy as np import pytest from flax import nnx +import keras from keras.src import backend -from keras.src.backend.jax.core import Variable as KerasJaxVariable from keras.src import testing +from keras.src.backend.jax.core import Variable as KerasJaxVariable @pytest.mark.skipif( @@ -13,16 +16,35 @@ reason="JAX backend specific test for core Variable integration with NNX.", ) class JaxCoreVariableTest(testing.TestCase): + def setup(self): + super().setup() + + class NNXModel(nnx.Module): + def __init__(self, rngs): + self.linear = nnx.Linear(2, 3, rngs=rngs) + self.custom_variable = KerasJaxVariable(jnp.ones((1, 3))) + + def __call__(self, x): + return self.linear(x) + self.custom_variable + + self.nnx_model = NNXModel(rngs=nnx.Rngs(0)) + self.keras_nnx_model = keras.Sequential( + [keras.layers.Dense(units=1, input_shape=(10,))] + ) + self.single_dummy_input = np.random.rand(1, 10) + def test_variable_in_nnx_module(self): - class Model(nnx.Module): - def __init__(self, rngs): - self.linear = nnx.Linear(2, 3, rngs=rngs) - self.custom_variable = KerasJaxVariable(jnp.ones((1, 3))) - def __call__(self, x): - return self.linear(x) + self.custom_variable - - model = Model(rngs=nnx.Rngs(0)) - self.assertTrue(hasattr(model.custom_variable,"_trace_state")) - self.assertIsNotNone(model.custom_variable._trace_state) - self.assertAllEqual(model.custom_variable.value, [[1, 1, 1]]) - self.assertTrue(isinstance(model.custom_variable, nnx.Variable)) + self.assertTrue(hasattr(self.nnx_model.custom_variable, "_trace_state")) + 
self.assertIsNotNone(self.nnx_model.custom_variable._trace_state) + self.assertAllEqual(self.nnx_model.custom_variable.value, [[1, 1, 1]]) + self.assertTrue( + isinstance(self.nnx_model.custom_variable, nnx.Variable) + ) + + def test_model_saving(self): + path = os.path.join(self.get_temp_dir(), "model.keras") + original_outputs = self.keras_nnx_model(self.single_dummy_input) + self.keras_nnx_model.save(path, save_format="keras_v3") + restored_model = keras.models.load_model(path) + restored_outputs = restored_model(self.single_dummy_input) + self.assertAllEqual(original_outputs, restored_outputs) From 968d8049c742782af7df143ee45a43606a568508 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 2 Jun 2025 20:11:33 +0000 Subject: [PATCH 022/103] add __hash__ --- keras/src/backend/jax/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index ae7a516f2fe1..69cdca95cd30 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -238,6 +238,9 @@ def value(self): return self._maybe_autocast(stateless_value) return self._maybe_autocast(current_value) + + def __hash__(self): + return id(self) def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): From b99571a79aa3820940e6d0d22a493b0c02319c52 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 2 Jun 2025 20:37:59 +0000 Subject: [PATCH 023/103] update variable value updates --- keras/src/backend/jax/core.py | 27 ++++++++++++++++++++++++++- keras/src/backend/jax/core_test.py | 8 ++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 69cdca95cd30..4afa3534b320 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -53,6 +53,31 @@ def __init__( elif "mutable" not in nnx_metadata: nnx_metadata["mutable"] = actual_nnx_mutable + # hook for NNX value updates + def _sync_keras_value_on_nnx_update(variable, new_nnx_value): + """Updates Keras's _value when NNX's value changes.""" + if hasattr(variable, "_value"): + variable._value = new_nnx_value + return new_nnx_value + + # Add the hook to nnx_metadata + if "on_set_value" in nnx_metadata: + existing_on_set_value = nnx_metadata["on_set_value"] + + def _chained_sync_hook(variable, new_nnx_value): + # Call existing hook + processed_value_by_existing_hook = existing_on_set_value( + variable, new_nnx_value + ) + # Sync Keras's _value + if hasattr(variable, "_value"): + variable._value = processed_value_by_existing_hook + return processed_value_by_existing_hook + + nnx_metadata["on_set_value"] = _chained_sync_hook + else: + nnx_metadata["on_set_value"] = _sync_keras_value_on_nnx_update + # Initialize nnx.Variable first. if shape is not None and dtype is not None: # If initializer is a Keras callable, it's not ready yet. 
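            # (A placeholder value gets wrapped at this point; the real
            # initializer only runs later, inside KerasVariable.__init__.)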
@@ -238,7 +263,7 @@ def value(self): return self._maybe_autocast(stateless_value) return self._maybe_autocast(current_value) - + def __hash__(self): return id(self) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index 06d6670cc02c..b0599563c9fe 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -1,5 +1,6 @@ import os +import jax import jax.numpy as jnp import numpy as np import pytest @@ -48,3 +49,10 @@ def test_model_saving(self): restored_model = keras.models.load_model(path) restored_outputs = restored_model(self.single_dummy_input) self.assertAllEqual(original_outputs, restored_outputs) + + def test_keras_variable_nnx_split_merge_sync(self): + variable1 = keras.Variable(jnp.array(1.0)) + graphdef, state = nnx.split(variable1) + state = jax.tree.map(lambda x: x + 1, state) + variable2 = nnx.merge(graphdef, state) + self.assertEqual(variable2._value, variable2.value) From ed0bc009d875cc3ec0066f9190782357fc0614b5 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 2 Jun 2025 20:56:43 +0000 Subject: [PATCH 024/103] sync value properly --- keras/src/backend/jax/core.py | 42 ++++++++--------------------------- 1 file changed, 9 insertions(+), 33 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 4afa3534b320..6ebacbc4cbb4 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -53,31 +53,6 @@ def __init__( elif "mutable" not in nnx_metadata: nnx_metadata["mutable"] = actual_nnx_mutable - # hook for NNX value updates - def _sync_keras_value_on_nnx_update(variable, new_nnx_value): - """Updates Keras's _value when NNX's value changes.""" - if hasattr(variable, "_value"): - variable._value = new_nnx_value - return new_nnx_value - - # Add the hook to nnx_metadata - if "on_set_value" in nnx_metadata: - existing_on_set_value = nnx_metadata["on_set_value"] - - def _chained_sync_hook(variable, new_nnx_value): - # Call existing hook - processed_value_by_existing_hook = existing_on_set_value( - variable, new_nnx_value - ) - # Sync Keras's _value - if hasattr(variable, "_value"): - variable._value = processed_value_by_existing_hook - return processed_value_by_existing_hook - - nnx_metadata["on_set_value"] = _chained_sync_hook - else: - nnx_metadata["on_set_value"] = _sync_keras_value_on_nnx_update - # Initialize nnx.Variable first. if shape is not None and dtype is not None: # If initializer is a Keras callable, it's not ready yet. @@ -112,9 +87,15 @@ def _chained_sync_hook(variable, new_nnx_value): name=name, ) - # self._value now holds the true JAX array from KerasVariable init. - # Update nnx.Variable's internal value (raw_value) to match. - object.__setattr__(self, "raw_value", self._value) + @property + def _value(self): + if hasattr(self, "raw_value"): + return self.raw_value + return None + + @_value.setter + def _value(self, new_keras_value): + self._direct_assign(new_keras_value) def __getstate__(self): # Get the state from KerasVariable (attributes in __dict__) @@ -200,9 +181,6 @@ def _direct_assign(self, value): else: processed_value = value - # Update Keras's internal _value - self._value = processed_value - # Ensure that nnx.Variable part is initialized if not hasattr(self, "_var_metadata"): # todo: should add a warning @@ -246,8 +224,6 @@ def value(self): current_value = self._value else: current_value = self.raw_value - # NNX specific: if raw_value is a mutable_array wrapper, get the - # actual array. 
if ( hasattr(self, "_var_metadata") and "on_get_value" in self._var_metadata From 460e0e256b0ec9ec3243c9690059c897a37703a6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 22:36:44 +0000 Subject: [PATCH 025/103] update flag based routing between nnx and jax --- examples/demo_jax_distributed.py | 5 +- guides/distributed_training_with_jax.py | 19 +++++++- keras/src/backend/config.py | 63 +++++++++++++++++++++--- keras/src/backend/jax/__init__.py | 6 +++ keras/src/backend/jax/core.py | 44 +++++++++++++++-- keras/src/backend/jax/core_test.py | 65 ++++++++++++++++++++++++- keras/src/backend/jax/layer.py | 6 ++- keras/src/backend/jax/trainer.py | 23 +++++++-- keras/src/layers/layer.py | 6 ++- keras/src/random/random_test.py | 6 +-- keras/src/random/seed_generator_test.py | 4 +- 11 files changed, 217 insertions(+), 30 deletions(-) diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index 9bee7c48f792..b54b3cc2f1db 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -12,7 +12,6 @@ import jax.numpy as jnp import tensorflow as tf # just for tf.data import keras # Keras multi-backend -from flax import nnx import numpy as np from tqdm import tqdm @@ -264,7 +263,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): # Training step: Keras provides a pure functional optimizer.stateless_apply -@nnx.jit +@jax.jit def train_step(train_state, x, y): (loss_value, non_trainable_variables), grads = compute_gradients( train_state.trainable_variables, @@ -302,7 +301,7 @@ def train_step(train_state, x, y): sharded_data = jax.device_put(data.numpy(), data_sharding) -@nnx.jit +@jax.jit def predict(data): predictions, updated_non_trainable_variables = model.stateless_call( device_train_state.trainable_variables, diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 41604a2f3ff0..ee57e0992abc 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -53,6 +53,7 @@ from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P +from keras.src.backend.config import is_nnx_backend_enabled def get_model(): @@ -185,8 +186,24 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) +def conditional_jit(condition, *args, **kwargs): + """ + Applies @jax.jit or @nnx.jit based on a condition. + """ + + def decorator(func): + if condition: + print("Using @nnx.jit") + return nnx.jit(func, *args, **kwargs) + else: + print("Using @jax.jit") + return jax.jit(func, *args, **kwargs) + + return decorator + + # Training step, Keras provides a pure functional optimizer.stateless_apply -@nnx.jit +@conditional_jit(is_nnx_backend_enabled) def train_step(train_state, x, y): ( trainable_variables, diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 68f8e1014639..9520074d5735 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -2,6 +2,7 @@ import os from keras.src.api_export import keras_export +from keras.src.backend.common import global_state as keras_global_state # The type of float to use throughout a session. _FLOATX = "float32" @@ -15,6 +16,9 @@ # Default backend: TensorFlow. _BACKEND = "tensorflow" +# Default for NNX backend features. +_NNX_ENABLED_KEY = "nnx_enabled" + # Cap run duration for debugging. 
_MAX_EPOCHS = None _MAX_STEPS_PER_EPOCH = None @@ -187,9 +191,7 @@ def enable_flash_attention(): used. Typically, the inputs must be in `float16` or `bfloat16` dtype, and input layout requirements may vary depending on the backend. """ - from keras.src.backend.common import global_state - - global_state.set_global_attribute("flash_attention", None) + keras_global_state.set_global_attribute("flash_attention", None) @keras_export("keras.config.disable_flash_attention") @@ -203,9 +205,7 @@ def disable_flash_attention(): Once disabled, supported layers like `MultiHeadAttention` will not use flash attention for faster computations. """ - from keras.src.backend.common import global_state - - global_state.set_global_attribute("flash_attention", False) + keras_global_state.set_global_attribute("flash_attention", False) @keras_export("keras.config.is_flash_attention_enabled") @@ -225,9 +225,45 @@ def is_flash_attention_enabled(): Returns: `False` if disabled; otherwise, it indicates that it is enabled. """ - from keras.src.backend.common import global_state + return keras_global_state.get_global_attribute( + "flash_attention", default=None + ) + + +@keras_export("keras.config.enable_nnx_backend") +def enable_nnx_backend(): + """Enable NNX specific features for the JAX backend. + + When enabled, Keras may utilize NNX-specific optimizations or features + if the JAX backend is active. This is disabled by default. + """ + keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, True) + + +@keras_export("keras.config.disable_nnx_backend") +def disable_nnx_backend(): + """Disable NNX specific features for the JAX backend. + + This function explicitly disables any NNX-specific backend features. + """ + keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, False) + + +@keras_export("keras.config.is_nnx_backend_enabled") +def is_nnx_backend_enabled(): + """Checks whether NNX specific features are enabled for the JAX backend. + + Returns: + bool: `True` if NNX backend features are enabled, `False` otherwise. + Defaults to `False`. + """ + return keras_global_state.get_global_attribute( + _NNX_ENABLED_KEY, default=False + ) + - return global_state.get_global_attribute("flash_attention", default=None) +def set_nnx_backend_enabled(value: bool): + keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, bool(value)) def standardize_data_format(data_format): @@ -274,10 +310,14 @@ def keras_home(): _backend = _config.get("backend", _BACKEND) _image_data_format = _config.get("image_data_format", image_data_format()) assert _image_data_format in {"channels_last", "channels_first"} + _nnx_enabled_json = _config.get(_NNX_ENABLED_KEY, is_nnx_backend_enabled()) + if not isinstance(_nnx_enabled_json, bool): + _nnx_enabled_json = str(_nnx_enabled_json).lower() == "true" set_floatx(_floatx) set_epsilon(_epsilon) set_image_data_format(_image_data_format) + set_nnx_backend_enabled(_nnx_enabled_json) _BACKEND = _backend # Save config file, if possible. 
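For reference, a hand-edited `~/.keras/keras.json` that opts into the new flag would look roughly like the sketch below; the key is the `_NNX_ENABLED_KEY` constant (`"nnx_enabled"`), and a string value such as `"true"` is also tolerated, which is what the `str(...).lower() == "true"` coercion above handles.

import json
import os

# Sketch of the keras.json contents the loader above understands.
config = {
    "floatx": "float32",
    "epsilon": 1e-07,
    "backend": "jax",
    "image_data_format": "channels_last",
    "nnx_enabled": True,  # may also arrive as the string "true"
}
path = os.path.expanduser(os.path.join("~", ".keras", "keras.json"))
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f:
    json.dump(config, f, indent=4)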
@@ -295,6 +335,7 @@ def keras_home(): "epsilon": epsilon(), "backend": _BACKEND, "image_data_format": image_data_format(), + _NNX_ENABLED_KEY: is_nnx_backend_enabled(), } try: with open(_config_path, "w") as f: @@ -312,6 +353,12 @@ def keras_home(): _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"]) if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ: _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"]) +if "KERAS_NNX_ENABLED" in os.environ: + _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower() + if _nnx_enabled_env in ("true", "1"): + set_nnx_backend_enabled(True) + elif _nnx_enabled_env in ("false", "0"): + set_nnx_backend_enabled(False) if _BACKEND != "tensorflow": # If we are not running on the tensorflow backend, we should stop tensorflow diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 12d25effa6fc..b800180723c5 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,3 +1,4 @@ +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.backend.jax import core from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image @@ -10,6 +11,11 @@ from keras.src.backend.jax.core import IS_THREAD_SAFE from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS + +if is_nnx_backend_enabled: + from keras.src.backend.jax.core import NnxVariable as Variable +else: + from keras.src.backend.jax.core import JaxVariable as Variable from keras.src.backend.jax.core import Variable from keras.src.backend.jax.core import cast from keras.src.backend.jax.core import compute_output_spec diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 6ebacbc4cbb4..c0b80d826c96 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -22,7 +22,43 @@ IS_THREAD_SAFE = True -class Variable(KerasVariable, nnx.Variable): +class JaxVariable(KerasVariable): + def __init__(self, *args, layout=None, **kwargs): + # Intercept layout parameter so that it is available + # during initialization. + self._layout = layout + super().__init__(*args, **kwargs) + + def _initialize(self, value): + # Note that variable.shape is needed by distribution_lib + self._shape = self._validate_shape(value.shape) + # We can't import the keras/distribution/distribution_lib + # due to circular dependency. + distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + tensor_layout = distribution.get_variable_layout(self) + from keras.src.distribution import TensorLayout + + if isinstance(tensor_layout, TensorLayout): + self._layout = tensor_layout.backend_layout + else: + self._layout = tensor_layout + self._direct_assign(value) + + def _direct_assign(self, value): + if self._layout is not None: + value = distribution_lib.distribute_variable(value, self._layout) + self._value = value + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype, sparse=False) + + # Overload native accessor. + def __jax_array__(self): + return self.value + + +class NnxVariable(KerasVariable, nnx.Variable): def __init__( self, initializer, @@ -76,7 +112,7 @@ def __init__( object.__setattr__(self, "_layout", layout) # Initialize KerasVariable. 
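        # (Ordering matters: KerasVariable.__init__ triggers _initialize /
        # _direct_assign, which expect the nnx state and the `_layout`
        # set above to already exist.)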
- super(Variable, self).__init__( + super(NnxVariable, self).__init__( initializer=initializer, shape=shape, dtype=dtype, @@ -257,7 +293,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): # an existing distributed jax array will raise error. return x - if isinstance(x, Variable): + if isinstance(x, (JaxVariable, NnxVariable)): if dtype is not None and x.dtype != dtype: return x.value.astype(dtype) return x.value @@ -541,7 +577,7 @@ def fori_loop(lower, upper, body_fun, init_val): def stop_gradient(variable): - if isinstance(variable, Variable): + if isinstance(variable, (JaxVariable, NnxVariable)): variable = variable.value return jax.lax.stop_gradient(variable) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index b0599563c9fe..b90d881d521b 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -8,14 +8,22 @@ import keras from keras.src import backend +from keras.src import layers from keras.src import testing -from keras.src.backend.jax.core import Variable as KerasJaxVariable +from keras.src.backend.jax.core import JaxVariable +from keras.src.backend.jax.core import NnxVariable +from keras.src.backend.jax.layer import JaxLayer +from keras.src.backend.jax.layer import NnxLayer @pytest.mark.skipif( backend.backend() != "jax", reason="JAX backend specific test for core Variable integration with NNX.", ) +@pytest.mark.skipif( + not keras.config.is_nnx_backend_enabled(), + reason="Test requires NNX backend to be enabled by default for setup.", +) class JaxCoreVariableTest(testing.TestCase): def setup(self): super().setup() @@ -23,7 +31,9 @@ def setup(self): class NNXModel(nnx.Module): def __init__(self, rngs): self.linear = nnx.Linear(2, 3, rngs=rngs) - self.custom_variable = KerasJaxVariable(jnp.ones((1, 3))) + # Use NnxVariable directly as KerasJaxVariable + # might be JaxVariable if NNX is disabled globally. 
+ self.custom_variable = NnxVariable(jnp.ones((1, 3))) def __call__(self, x): return self.linear(x) + self.custom_variable @@ -56,3 +66,54 @@ def test_keras_variable_nnx_split_merge_sync(self): state = jax.tree.map(lambda x: x + 1, state) variable2 = nnx.merge(graphdef, state) self.assertEqual(variable2._value, variable2.value) + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="JAX backend specific test for NNX flag.", +) +class JaxNnxFlagTest(testing.TestCase): + def tearDown(self): + # Reset the flag to its default (False) after each test + keras.config.disable_nnx_backend() + super().tearDown() + + def test_variable_selection_based_on_nnx_flag(self): + # Test with NNX backend enabled + keras.config.enable_nnx_backend() + self.assertTrue(keras.config.is_nnx_backend_enabled()) + var_nnx_enabled = backend.Variable(1.0) + self.assertIsInstance(var_nnx_enabled, NnxVariable) + self.assertNotIsInstance(var_nnx_enabled, JaxVariable) + + # Test with NNX backend disabled + keras.config.disable_nnx_backend() + self.assertFalse(keras.config.is_nnx_backend_enabled()) + var_nnx_disabled = backend.Variable(1.0) + self.assertIsInstance(var_nnx_disabled, JaxVariable) + self.assertNotIsInstance(var_nnx_disabled, NnxVariable) + + def test_layer_backend_selection_based_on_nnx_flag(self): + # Test with NNX backend enabled + keras.config.enable_nnx_backend() + self.assertTrue(keras.config.is_nnx_backend_enabled()) + + class MyLayerNnxEnabled(layers.Layer): + pass + + layer_nnx_enabled = MyLayerNnxEnabled() + self.assertIsInstance(layer_nnx_enabled, NnxLayer) + self.assertNotIsInstance(layer_nnx_enabled, JaxLayer) + + # Test with NNX backend disabled + # Must clear global state to re-evaluate Layer's base class + keras.src.backend.common.global_state.clear_session() + keras.config.disable_nnx_backend() + self.assertFalse(keras.config.is_nnx_backend_enabled()) + + class MyLayerNnxDisabled(layers.Layer): + pass + + layer_nnx_disabled = MyLayerNnxDisabled() + self.assertIsInstance(layer_nnx_disabled, JaxLayer) + self.assertNotIsInstance(layer_nnx_disabled, NnxLayer) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index 70c486d3e178..8d1cd6242bb1 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,6 +1,10 @@ from flax import nnx -class JaxLayer(nnx.Module): +class JaxLayer: + pass + + +class NnxLayer(nnx.Module): def __init_subclass__(cls): super().__init_subclass__() diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 6095dd6e0efe..484034d3ea47 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -13,6 +13,7 @@ from keras.src import tree from keras.src.backend import config from keras.src.backend import distribution_lib as jax_distribution_lib +from keras.src.config import is_nnx_backend_enabled from keras.src.distribution import distribution_lib from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import array_slicing @@ -234,7 +235,10 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - concatenate = nnx.jit(concatenate) + if is_nnx_backend_enabled: + concatenate = nnx.jit(concatenate) + else: + concatenate = jax.jit(concatenate) def iterator_step(state, iterator): data = next(iterator) @@ -278,7 +282,10 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. 
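            # (donate_argnums=0 donates the incoming state pytree to XLA,
            # which is what lets the output buffers alias the inputs.)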
- train_step = nnx.jit(self.train_step, donate_argnums=0) + if is_nnx_backend_enabled: + train_step = nnx.jit(self.train_step, donate_argnums=0) + else: + train_step = jax.jit(self.train_step, donate_argnums=0) else: train_step = self.train_step @@ -294,7 +301,10 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - test_step = nnx.jit(self.test_step, donate_argnums=0) + if is_nnx_backend_enabled: + test_step = nnx.jit(self.test_step, donate_argnums=0) + else: + test_step = jax.jit(self.test_step, donate_argnums=0) else: test_step = self.test_step @@ -311,7 +321,10 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = nnx.jit(predict_step, donate_argnums=0) + if is_nnx_backend_enabled: + predict_step = nnx.jit(predict_step, donate_argnums=0) + else: + predict_step = jax.jit(predict_step, donate_argnums=0) _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -905,7 +918,7 @@ def _enforce_jax_state_sharding( Since the output of the train/eval step will be used as inputs to next step, we need to ensure that they have the same sharding spec, so that - nnx.jit won't have to recompile the train/eval function. + nnx.jit/jax.jit won't have to recompile the train/eval function. Note that this function will also rely on the recorded sharding spec for each of states. diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index bd9f1143d91e..72a41210c5fb 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -38,6 +38,7 @@ from keras.src.backend.common.name_scope import current_path from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope +from keras.src.config import is_nnx_backend_enabled from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec @@ -53,7 +54,10 @@ if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": - from keras.src.backend.jax.layer import JaxLayer as BackendLayer + if is_nnx_backend_enabled: + from keras.src.backend.jax.layer import NnxLayer as BackendLayer + else: + from keras.src.backend.jax.layer import JaxLayer as BackendLayer elif backend.backend() == "torch": from keras.src.backend.torch.layer import TorchLayer as BackendLayer elif backend.backend() == "numpy": diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py index d93f6d4557db..ae5f103af0e6 100644 --- a/keras/src/random/random_test.py +++ b/keras/src/random/random_test.py @@ -1,7 +1,7 @@ +import jax import numpy as np import pytest from absl.testing import parameterized -from flax import nnx import keras from keras.src import backend @@ -385,7 +385,7 @@ def test_dropout_jax_jit_stateless(self): x = ops.ones(3) - @nnx.jit + @jax.jit def train_step(x): with keras.src.backend.StatelessScope(): x = keras.layers.Dropout(rate=0.1)(x, training=True) @@ -414,7 +414,7 @@ def test_jax_rngkey_seed(self): reason="This test requires `jax` as the backend.", ) def test_jax_unseed_disallowed_during_tracing(self): - @nnx.jit + @jax.jit def jit_fn(): return random.randint((2, 2), 0, 10, seed=None) diff --git a/keras/src/random/seed_generator_test.py b/keras/src/random/seed_generator_test.py 
index a042e165a7c3..b54e74e5b61d 100644 --- a/keras/src/random/seed_generator_test.py +++ b/keras/src/random/seed_generator_test.py @@ -1,6 +1,6 @@ +import jax import numpy as np import pytest -from flax import nnx from keras.src import backend from keras.src import ops @@ -79,7 +79,7 @@ def test_seed_generator_unexpected_kwargs(self): backend.backend() != "jax", reason="This test requires the JAX backend" ) def test_jax_tracing_with_global_seed_generator(self): - @nnx.jit + @jax.jit def traced_function(): return seed_generator.global_seed_generator().next() From 34f27e9de5025acf127354faae39e19ec5b53556 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 22:48:19 +0000 Subject: [PATCH 026/103] clean up --- integration_tests/jax_custom_fit_test.py | 2 -- keras/src/random/random_test.py | 4 +++- keras/src/random/seed_generator_test.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/integration_tests/jax_custom_fit_test.py b/integration_tests/jax_custom_fit_test.py index d69b3e35219c..9c9eee59f114 100644 --- a/integration_tests/jax_custom_fit_test.py +++ b/integration_tests/jax_custom_fit_test.py @@ -30,7 +30,6 @@ def compute_loss_and_updates( return loss, (y_pred, non_trainable_variables) def train_step(self, state, data): - print("inside train step with data", data) ( trainable_variables, non_trainable_variables, @@ -92,7 +91,6 @@ def metrics(self): model.compile(optimizer="adam") x = np.random.random((64, 32)) y = np.random.random((64, 1)) - history = model.fit(x, y, epochs=1) assert "loss" in history.history diff --git a/keras/src/random/random_test.py b/keras/src/random/random_test.py index ae5f103af0e6..9e78b8748b4d 100644 --- a/keras/src/random/random_test.py +++ b/keras/src/random/random_test.py @@ -1,4 +1,3 @@ -import jax import numpy as np import pytest from absl.testing import parameterized @@ -381,6 +380,7 @@ def test_uniform_dtype_validation(self): reason="This test requires `jax` as the backend.", ) def test_dropout_jax_jit_stateless(self): + import jax import jax.numpy as jnp x = ops.ones(3) @@ -414,6 +414,8 @@ def test_jax_rngkey_seed(self): reason="This test requires `jax` as the backend.", ) def test_jax_unseed_disallowed_during_tracing(self): + import jax + @jax.jit def jit_fn(): return random.randint((2, 2), 0, 10, seed=None) diff --git a/keras/src/random/seed_generator_test.py b/keras/src/random/seed_generator_test.py index b54e74e5b61d..d1101e0a871a 100644 --- a/keras/src/random/seed_generator_test.py +++ b/keras/src/random/seed_generator_test.py @@ -1,4 +1,3 @@ -import jax import numpy as np import pytest @@ -79,6 +78,8 @@ def test_seed_generator_unexpected_kwargs(self): backend.backend() != "jax", reason="This test requires the JAX backend" ) def test_jax_tracing_with_global_seed_generator(self): + import jax + @jax.jit def traced_function(): return seed_generator.global_seed_generator().next() From 427ff821cc6fd8849f522707f77cb97684dee0b1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 23:16:13 +0000 Subject: [PATCH 027/103] fix circular import error --- keras/src/backend/config.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 9520074d5735..b051a27573b0 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -2,7 +2,6 @@ import os from keras.src.api_export import keras_export -from keras.src.backend.common import global_state as keras_global_state # The type 
of float to use throughout a session. _FLOATX = "float32" @@ -191,7 +190,9 @@ def enable_flash_attention(): used. Typically, the inputs must be in `float16` or `bfloat16` dtype, and input layout requirements may vary depending on the backend. """ - keras_global_state.set_global_attribute("flash_attention", None) + from keras.src.backend.common import global_state + + global_state.set_global_attribute("flash_attention", None) @keras_export("keras.config.disable_flash_attention") @@ -205,7 +206,9 @@ def disable_flash_attention(): Once disabled, supported layers like `MultiHeadAttention` will not use flash attention for faster computations. """ - keras_global_state.set_global_attribute("flash_attention", False) + from keras.src.backend.common import global_state + + global_state.set_global_attribute("flash_attention", False) @keras_export("keras.config.is_flash_attention_enabled") @@ -225,9 +228,9 @@ def is_flash_attention_enabled(): Returns: `False` if disabled; otherwise, it indicates that it is enabled. """ - return keras_global_state.get_global_attribute( - "flash_attention", default=None - ) + from keras.src.backend.common import global_state + + return global_state.get_global_attribute("flash_attention", default=None) @keras_export("keras.config.enable_nnx_backend") @@ -237,7 +240,9 @@ def enable_nnx_backend(): When enabled, Keras may utilize NNX-specific optimizations or features if the JAX backend is active. This is disabled by default. """ - keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, True) + from keras.src.backend.common import global_state + + global_state.set_global_attribute(_NNX_ENABLED_KEY, True) @keras_export("keras.config.disable_nnx_backend") @@ -246,7 +251,9 @@ def disable_nnx_backend(): This function explicitly disables any NNX-specific backend features. """ - keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, False) + from keras.src.backend.common import global_state + + global_state.set_global_attribute(_NNX_ENABLED_KEY, False) @keras_export("keras.config.is_nnx_backend_enabled") @@ -257,13 +264,15 @@ def is_nnx_backend_enabled(): bool: `True` if NNX backend features are enabled, `False` otherwise. Defaults to `False`. 
""" - return keras_global_state.get_global_attribute( - _NNX_ENABLED_KEY, default=False - ) + from keras.src.backend.common import global_state + + return global_state.get_global_attribute(_NNX_ENABLED_KEY, default=False) def set_nnx_backend_enabled(value: bool): - keras_global_state.set_global_attribute(_NNX_ENABLED_KEY, bool(value)) + from keras.src.backend.common import global_state + + global_state.set_global_attribute(_NNX_ENABLED_KEY, bool(value)) def standardize_data_format(data_format): From c4ee191d4d14c2d5d8621426315ae20e5e6f3de6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 23:21:48 +0000 Subject: [PATCH 028/103] fix is nnx call enabled flag --- keras/src/layers/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 72a41210c5fb..415fc49a3016 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -54,7 +54,7 @@ if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": - if is_nnx_backend_enabled: + if is_nnx_backend_enabled(): from keras.src.backend.jax.layer import NnxLayer as BackendLayer else: from keras.src.backend.jax.layer import JaxLayer as BackendLayer From 44414dc6f92fde7968144f44dcdfcee88202d06c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 23:35:44 +0000 Subject: [PATCH 029/103] attemptto fix circular import error --- keras/src/backend/config.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index b051a27573b0..5b0831fdc5f1 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -362,12 +362,12 @@ def keras_home(): _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"]) if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ: _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"]) -if "KERAS_NNX_ENABLED" in os.environ: - _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower() - if _nnx_enabled_env in ("true", "1"): - set_nnx_backend_enabled(True) - elif _nnx_enabled_env in ("false", "0"): - set_nnx_backend_enabled(False) +# if "KERAS_NNX_ENABLED" in os.environ: +# _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower() +# if _nnx_enabled_env in ("true", "1"): +# set_nnx_backend_enabled(True) +# elif _nnx_enabled_env in ("false", "0"): +# set_nnx_backend_enabled(False) if _BACKEND != "tensorflow": # If we are not running on the tensorflow backend, we should stop tensorflow From 0953d9931b677549ee476ff95945909b187903c6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 3 Jun 2025 23:54:48 +0000 Subject: [PATCH 030/103] try again --- keras/src/layers/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 415fc49a3016..1eff520ba8a2 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -38,7 +38,7 @@ from keras.src.backend.common.name_scope import current_path from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope -from keras.src.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec From 6d54a7e5b409c0bd5fc2ca6e878a7e74681c070d Mon Sep 17 00:00:00 2001 
From: divyashreepathihalli
Date: Wed, 4 Jun 2025 00:10:23 +0000
Subject: [PATCH 031/103] fix import error

---
 keras/src/backend/config.py | 28 +++++++++++++++++-----------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py
index 5b0831fdc5f1..c77a199f056b 100644
--- a/keras/src/backend/config.py
+++ b/keras/src/backend/config.py
@@ -306,6 +306,7 @@ def keras_home():

 # Attempt to read Keras config file.
 _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json"))
+_initial_nnx_value_from_config = False  # Safe default before reading config
 if os.path.exists(_config_path):
     try:
         with open(_config_path) as f:
@@ -319,14 +320,15 @@ def keras_home():
     _backend = _config.get("backend", _BACKEND)
     _image_data_format = _config.get("image_data_format", image_data_format())
     assert _image_data_format in {"channels_last", "channels_first"}
-    _nnx_enabled_json = _config.get(_NNX_ENABLED_KEY, is_nnx_backend_enabled())
-    if not isinstance(_nnx_enabled_json, bool):
-        _nnx_enabled_json = str(_nnx_enabled_json).lower() == "true"
+    _initial_nnx_value_from_config = _config.get(_NNX_ENABLED_KEY, False)
+    if not isinstance(_initial_nnx_value_from_config, bool):
+        _initial_nnx_value_from_config = (
+            str(_initial_nnx_value_from_config).lower() == "true"
+        )

     set_floatx(_floatx)
     set_epsilon(_epsilon)
     set_image_data_format(_image_data_format)
-    set_nnx_backend_enabled(_nnx_enabled_json)
     _BACKEND = _backend

 # Save config file, if possible.
@@ -339,12 +341,13 @@ def keras_home():
     pass

 if not os.path.exists(_config_path):
+    _current_nnx_status_for_saving = is_nnx_backend_enabled()
     _config = {
         "floatx": floatx(),
         "epsilon": epsilon(),
         "backend": _BACKEND,
         "image_data_format": image_data_format(),
-        _NNX_ENABLED_KEY: is_nnx_backend_enabled(),
+        _NNX_ENABLED_KEY: _current_nnx_status_for_saving,
     }
     try:
         with open(_config_path, "w") as f:
@@ -353,6 +356,9 @@ def keras_home():
     # Except permission denied.
     pass

+# Ensure global state is initialized from config
+set_nnx_backend_enabled(_initial_nnx_value_from_config)
+
 # Set backend based on KERAS_BACKEND flag, if applicable.
 if "KERAS_BACKEND" in os.environ:
     _backend = os.environ["KERAS_BACKEND"]
@@ -362,12 +368,12 @@ def keras_home():
     _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"])
 if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
     _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])
-# if "KERAS_NNX_ENABLED" in os.environ:
-#     _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower()
-#     if _nnx_enabled_env in ("true", "1"):
-#         set_nnx_backend_enabled(True)
-#     elif _nnx_enabled_env in ("false", "0"):
-#         set_nnx_backend_enabled(False)
+if "KERAS_NNX_ENABLED" in os.environ:
+    _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower()
+    if _nnx_enabled_env in ("true", "1"):
+        set_nnx_backend_enabled(True)
+    elif _nnx_enabled_env in ("false", "0"):
+        set_nnx_backend_enabled(False)

 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow

From 64adbaf3b4dab3d376af7d436200382ea5b1f3a3 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Wed, 4 Jun 2025 00:13:15 +0000
Subject: [PATCH 032/103] reformat

---
 keras/src/backend/config.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py
index c77a199f056b..04d8b1ce5a94 100644
--- a/keras/src/backend/config.py
+++ b/keras/src/backend/config.py
@@ -356,8 +356,6 @@ def keras_home():
     # Except permission denied.
     pass

-# Ensure global state is initialized from config
-set_nnx_backend_enabled(_initial_nnx_value_from_config)

 # Set backend based on KERAS_BACKEND flag, if applicable.
 if "KERAS_BACKEND" in os.environ:
@@ -368,12 +366,7 @@ def keras_home():
     _MAX_EPOCHS = int(os.environ["KERAS_MAX_EPOCHS"])
 if "KERAS_MAX_STEPS_PER_EPOCH" in os.environ:
     _MAX_STEPS_PER_EPOCH = int(os.environ["KERAS_MAX_STEPS_PER_EPOCH"])
-if "KERAS_NNX_ENABLED" in os.environ:
-    _nnx_enabled_env = os.environ["KERAS_NNX_ENABLED"].lower()
-    if _nnx_enabled_env in ("true", "1"):
-        set_nnx_backend_enabled(True)
-    elif _nnx_enabled_env in ("false", "0"):
-        set_nnx_backend_enabled(False)
+

 if _BACKEND != "tensorflow":
     # If we are not running on the tensorflow backend, we should stop tensorflow
@@ -381,6 +374,9 @@ def keras_home():
     # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth
     os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

+# Ensure global state is initialized from config
+set_nnx_backend_enabled(_initial_nnx_value_from_config)
+

 @keras_export(
     [

From 6454800b467ad5576a4af91c266e637a79289747 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Wed, 4 Jun 2025 00:25:58 +0000
Subject: [PATCH 033/103] This has to fix it

---
 keras/src/backend/__init__.py | 2 +-
 keras/src/backend/config.py | 61 +++++++++++++++----------
 keras/src/backend/jax/__init__.py | 3 +-
 keras/src/backend/jax/trainer.py | 10 ++---
 4 files changed, 37 insertions(+), 39 deletions(-)

diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py
index 15f1af2145d5..a200b17c914e 100644
--- a/keras/src/backend/__init__.py
+++ b/keras/src/backend/__init__.py
@@ -39,7 +39,7 @@
 from keras.src.backend.tensorflow.core import Variable as BackendVariable
 elif backend() == "jax":
     from keras.src.backend.jax import *  # noqa: F403
-    from keras.src.backend.jax.core import Variable as BackendVariable
+    from keras.src.backend.jax import Variable as BackendVariable
 elif backend() == "torch":
     from keras.src.backend.torch import *  # noqa: F403
     from keras.src.backend.torch.core import Variable as BackendVariable
diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py
index 04d8b1ce5a94..8470f531ba7d 100644
--- a/keras/src/backend/config.py
+++ b/keras/src/backend/config.py
@@ -307,6 +307,7 @@ def keras_home():
 # Attempt to read Keras config file.
 _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json"))
 _initial_nnx_value_from_config = False  # Safe default before reading config
+
 if os.path.exists(_config_path):
     try:
         with open(_config_path) as f:
@@ -326,38 +327,13 @@ def keras_home():
             str(_initial_nnx_value_from_config).lower() == "true"
         )

+    # Apply basic configs that don't cause circular import
     set_floatx(_floatx)
     set_epsilon(_epsilon)
     set_image_data_format(_image_data_format)
     _BACKEND = _backend
-
-# Save config file, if possible.
-if not os.path.exists(_KERAS_DIR):
-    try:
-        os.makedirs(_KERAS_DIR)
-    except OSError:
-        # Except permission denied and potential race conditions
-        # in multi-threaded environments.
- pass - -if not os.path.exists(_config_path): - _current_nnx_status_for_saving = is_nnx_backend_enabled() - _config = { - "floatx": floatx(), - "epsilon": epsilon(), - "backend": _BACKEND, - "image_data_format": image_data_format(), - _NNX_ENABLED_KEY: _current_nnx_status_for_saving, - } - try: - with open(_config_path, "w") as f: - f.write(json.dumps(_config, indent=4)) - except IOError: - # Except permission denied. - pass - - -# Set backend based on KERAS_BACKEND flag, if applicable. +else: + _config = {} if "KERAS_BACKEND" in os.environ: _backend = os.environ["KERAS_BACKEND"] if _backend: @@ -374,9 +350,6 @@ def keras_home(): # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" -# Ensure global state is initialized from config -set_nnx_backend_enabled(_initial_nnx_value_from_config) - @keras_export( [ @@ -461,3 +434,29 @@ def max_steps_per_epoch(): `None`, no limit is applied. """ return _MAX_STEPS_PER_EPOCH + + +if not os.path.exists(_KERAS_DIR): + try: + os.makedirs(_KERAS_DIR) + except OSError: + # Except permission denied and potential race conditions + pass + +if not os.path.exists(_config_path): + _current_nnx_status_for_saving = is_nnx_backend_enabled() + _config_to_save = { + "floatx": floatx(), + "epsilon": epsilon(), + "backend": _BACKEND, # Use the final _BACKEND value + "image_data_format": image_data_format(), + _NNX_ENABLED_KEY: _current_nnx_status_for_saving, + } + try: + with open(_config_path, "w") as f: + f.write(json.dumps(_config_to_save, indent=4)) + except IOError: + # Except permission denied. + pass + +set_nnx_backend_enabled(_initial_nnx_value_from_config) diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index b800180723c5..859d5e3a1d8d 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -12,11 +12,10 @@ from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS -if is_nnx_backend_enabled: +if is_nnx_backend_enabled(): from keras.src.backend.jax.core import NnxVariable as Variable else: from keras.src.backend.jax.core import JaxVariable as Variable -from keras.src.backend.jax.core import Variable from keras.src.backend.jax.core import cast from keras.src.backend.jax.core import compute_output_spec from keras.src.backend.jax.core import cond diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 484034d3ea47..759bdbbdcf68 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -13,7 +13,7 @@ from keras.src import tree from keras.src.backend import config from keras.src.backend import distribution_lib as jax_distribution_lib -from keras.src.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.distribution import distribution_lib from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import array_slicing @@ -235,7 +235,7 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - if is_nnx_backend_enabled: + if is_nnx_backend_enabled(): concatenate = nnx.jit(concatenate) else: concatenate = jax.jit(concatenate) @@ -282,7 +282,7 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. 
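            # (Note the call parentheses added below: the bare function
            # object is always truthy, so the flag check otherwise always
            # picked nnx.jit.)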
- if is_nnx_backend_enabled: + if is_nnx_backend_enabled(): train_step = nnx.jit(self.train_step, donate_argnums=0) else: train_step = jax.jit(self.train_step, donate_argnums=0) @@ -301,7 +301,7 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - if is_nnx_backend_enabled: + if is_nnx_backend_enabled(): test_step = nnx.jit(self.test_step, donate_argnums=0) else: test_step = jax.jit(self.test_step, donate_argnums=0) @@ -321,7 +321,7 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - if is_nnx_backend_enabled: + if is_nnx_backend_enabled(): predict_step = nnx.jit(predict_step, donate_argnums=0) else: predict_step = jax.jit(predict_step, donate_argnums=0) From 001f112b929d40dac24976f948e93ca322f2edee Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 4 Jun 2025 00:39:28 +0000 Subject: [PATCH 034/103] api gen --- keras/api/_tf_keras/keras/config/__init__.py | 5 +++++ keras/api/config/__init__.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py index 106fd46a3291..536ae6d7b70c 100644 --- a/keras/api/_tf_keras/keras/config/__init__.py +++ b/keras/api/_tf_keras/keras/config/__init__.py @@ -8,15 +8,20 @@ from keras.src.backend.config import ( disable_flash_attention as disable_flash_attention, ) +from keras.src.backend.config import disable_nnx_backend as disable_nnx_backend from keras.src.backend.config import ( enable_flash_attention as enable_flash_attention, ) +from keras.src.backend.config import enable_nnx_backend as enable_nnx_backend from keras.src.backend.config import epsilon as epsilon from keras.src.backend.config import floatx as floatx from keras.src.backend.config import image_data_format as image_data_format from keras.src.backend.config import ( is_flash_attention_enabled as is_flash_attention_enabled, ) +from keras.src.backend.config import ( + is_nnx_backend_enabled as is_nnx_backend_enabled, +) from keras.src.backend.config import max_epochs as max_epochs from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch from keras.src.backend.config import set_epsilon as set_epsilon diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py index 106fd46a3291..536ae6d7b70c 100644 --- a/keras/api/config/__init__.py +++ b/keras/api/config/__init__.py @@ -8,15 +8,20 @@ from keras.src.backend.config import ( disable_flash_attention as disable_flash_attention, ) +from keras.src.backend.config import disable_nnx_backend as disable_nnx_backend from keras.src.backend.config import ( enable_flash_attention as enable_flash_attention, ) +from keras.src.backend.config import enable_nnx_backend as enable_nnx_backend from keras.src.backend.config import epsilon as epsilon from keras.src.backend.config import floatx as floatx from keras.src.backend.config import image_data_format as image_data_format from keras.src.backend.config import ( is_flash_attention_enabled as is_flash_attention_enabled, ) +from keras.src.backend.config import ( + is_nnx_backend_enabled as is_nnx_backend_enabled, +) from keras.src.backend.config import max_epochs as max_epochs from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch from keras.src.backend.config import set_epsilon as set_epsilon From 5f26958a72941dc47392f2cc09cd9d1585084f6a Mon Sep 17 
00:00:00 2001
From: divyashreepathihalli
Date: Wed, 4 Jun 2025 05:46:38 +0000
Subject: [PATCH 035/103] remove enable/disable configs - that does not work

---
 guides/distributed_training_with_jax.py | 2 +-
 keras/api/_tf_keras/keras/config/__init__.py | 2 -
 keras/api/config/__init__.py | 2 -
 keras/src/backend/config.py | 30 +++--------
 keras/src/backend/jax/core_test.py | 55 --------------------
 5 files changed, 8 insertions(+), 83 deletions(-)

diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py
index ee57e0992abc..90d3491121fa 100644
--- a/guides/distributed_training_with_jax.py
+++ b/guides/distributed_training_with_jax.py
@@ -203,7 +203,7 @@ def decorator(func):
 
 
 # Training step, Keras provides a pure functional optimizer.stateless_apply
-@conditional_jit(is_nnx_backend_enabled)
+@conditional_jit(is_nnx_backend_enabled())
 def train_step(train_state, x, y):
     (
         trainable_variables,
diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py
index 536ae6d7b70c..65e32dd7f4ee 100644
--- a/keras/api/_tf_keras/keras/config/__init__.py
+++ b/keras/api/_tf_keras/keras/config/__init__.py
@@ -8,11 +8,9 @@
 from keras.src.backend.config import (
     disable_flash_attention as disable_flash_attention,
 )
-from keras.src.backend.config import disable_nnx_backend as disable_nnx_backend
 from keras.src.backend.config import (
     enable_flash_attention as enable_flash_attention,
 )
-from keras.src.backend.config import enable_nnx_backend as enable_nnx_backend
 from keras.src.backend.config import epsilon as epsilon
 from keras.src.backend.config import floatx as floatx
 from keras.src.backend.config import image_data_format as image_data_format
diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py
index 536ae6d7b70c..65e32dd7f4ee 100644
--- a/keras/api/config/__init__.py
+++ b/keras/api/config/__init__.py
@@ -8,11 +8,9 @@
 from keras.src.backend.config import (
     disable_flash_attention as disable_flash_attention,
 )
-from keras.src.backend.config import disable_nnx_backend as disable_nnx_backend
 from keras.src.backend.config import (
     enable_flash_attention as enable_flash_attention,
 )
-from keras.src.backend.config import enable_nnx_backend as enable_nnx_backend
 from keras.src.backend.config import epsilon as epsilon
 from keras.src.backend.config import floatx as floatx
 from keras.src.backend.config import image_data_format as image_data_format
diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py
index 8470f531ba7d..40fe5ef860a4 100644
--- a/keras/src/backend/config.py
+++ b/keras/src/backend/config.py
@@ -233,29 +233,6 @@ def is_flash_attention_enabled():
     return global_state.get_global_attribute("flash_attention", default=None)
 
 
-@keras_export("keras.config.enable_nnx_backend")
-def enable_nnx_backend():
-    """Enable NNX specific features for the JAX backend.
-
-    When enabled, Keras may utilize NNX-specific optimizations or features
-    if the JAX backend is active. This is disabled by default.
-    """
-    from keras.src.backend.common import global_state
-
-    global_state.set_global_attribute(_NNX_ENABLED_KEY, True)
-
-
-@keras_export("keras.config.disable_nnx_backend")
-def disable_nnx_backend():
-    """Disable NNX specific features for the JAX backend.
-
-    This function explicitly disables any NNX-specific backend features.
- """ - from keras.src.backend.common import global_state - - global_state.set_global_attribute(_NNX_ENABLED_KEY, False) - - @keras_export("keras.config.is_nnx_backend_enabled") def is_nnx_backend_enabled(): """Checks whether NNX specific features are enabled for the JAX backend. @@ -459,4 +436,11 @@ def max_steps_per_epoch(): # Except permission denied. pass +if "KERAS_NNX_ENABLED" in os.environ: + env_val = os.environ["KERAS_NNX_ENABLED"].lower() + if env_val == "true": + _initial_nnx_value_from_config = True + elif env_val == "false": + _initial_nnx_value_from_config = False + set_nnx_backend_enabled(_initial_nnx_value_from_config) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index b90d881d521b..cf875335a3b8 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -8,12 +8,8 @@ import keras from keras.src import backend -from keras.src import layers from keras.src import testing -from keras.src.backend.jax.core import JaxVariable from keras.src.backend.jax.core import NnxVariable -from keras.src.backend.jax.layer import JaxLayer -from keras.src.backend.jax.layer import NnxLayer @pytest.mark.skipif( @@ -66,54 +62,3 @@ def test_keras_variable_nnx_split_merge_sync(self): state = jax.tree.map(lambda x: x + 1, state) variable2 = nnx.merge(graphdef, state) self.assertEqual(variable2._value, variable2.value) - - -@pytest.mark.skipif( - backend.backend() != "jax", - reason="JAX backend specific test for NNX flag.", -) -class JaxNnxFlagTest(testing.TestCase): - def tearDown(self): - # Reset the flag to its default (False) after each test - keras.config.disable_nnx_backend() - super().tearDown() - - def test_variable_selection_based_on_nnx_flag(self): - # Test with NNX backend enabled - keras.config.enable_nnx_backend() - self.assertTrue(keras.config.is_nnx_backend_enabled()) - var_nnx_enabled = backend.Variable(1.0) - self.assertIsInstance(var_nnx_enabled, NnxVariable) - self.assertNotIsInstance(var_nnx_enabled, JaxVariable) - - # Test with NNX backend disabled - keras.config.disable_nnx_backend() - self.assertFalse(keras.config.is_nnx_backend_enabled()) - var_nnx_disabled = backend.Variable(1.0) - self.assertIsInstance(var_nnx_disabled, JaxVariable) - self.assertNotIsInstance(var_nnx_disabled, NnxVariable) - - def test_layer_backend_selection_based_on_nnx_flag(self): - # Test with NNX backend enabled - keras.config.enable_nnx_backend() - self.assertTrue(keras.config.is_nnx_backend_enabled()) - - class MyLayerNnxEnabled(layers.Layer): - pass - - layer_nnx_enabled = MyLayerNnxEnabled() - self.assertIsInstance(layer_nnx_enabled, NnxLayer) - self.assertNotIsInstance(layer_nnx_enabled, JaxLayer) - - # Test with NNX backend disabled - # Must clear global state to re-evaluate Layer's base class - keras.src.backend.common.global_state.clear_session() - keras.config.disable_nnx_backend() - self.assertFalse(keras.config.is_nnx_backend_enabled()) - - class MyLayerNnxDisabled(layers.Layer): - pass - - layer_nnx_disabled = MyLayerNnxDisabled() - self.assertIsInstance(layer_nnx_disabled, JaxLayer) - self.assertNotIsInstance(layer_nnx_disabled, NnxLayer) From 782c65304ca5430b5d76cbc37785324e4793eb2e Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 18:01:28 +0000 Subject: [PATCH 036/103] adrress some comments --- guides/distributed_training_with_jax.py | 12 +++++++++++- keras/src/backend/config.py | 11 ++++------- keras/src/backend/jax/core.py | 10 +++++++++- keras/src/backend/jax/layer.py | 11 ++++++++++- 
keras/src/backend/jax/trainer.py | 10 +++++++++- keras/src/ops/operation.py | 13 +++++++++++-- requirements-jax-cuda.txt | 1 - requirements.txt | 1 - 8 files changed, 54 insertions(+), 15 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 90d3491121fa..bfbe94af452b 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -48,12 +48,22 @@ import numpy as np import tensorflow as tf import keras -from flax import nnx + from jax.experimental import mesh_utils from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend import config + +if config.is_nnx_backend_enabled(): + try: + from flax import nnx + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." + "Try: `pip install flax`" + ) def get_model(): diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 40fe5ef860a4..31268313ea42 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -15,9 +15,6 @@ # Default backend: TensorFlow. _BACKEND = "tensorflow" -# Default for NNX backend features. -_NNX_ENABLED_KEY = "nnx_enabled" - # Cap run duration for debugging. _MAX_EPOCHS = None _MAX_STEPS_PER_EPOCH = None @@ -243,13 +240,13 @@ def is_nnx_backend_enabled(): """ from keras.src.backend.common import global_state - return global_state.get_global_attribute(_NNX_ENABLED_KEY, default=False) + return global_state.get_global_attribute("nnx_enabled", default=False) def set_nnx_backend_enabled(value: bool): from keras.src.backend.common import global_state - global_state.set_global_attribute(_NNX_ENABLED_KEY, bool(value)) + global_state.set_global_attribute("nnx_enabled", bool(value)) def standardize_data_format(data_format): @@ -298,7 +295,7 @@ def keras_home(): _backend = _config.get("backend", _BACKEND) _image_data_format = _config.get("image_data_format", image_data_format()) assert _image_data_format in {"channels_last", "channels_first"} - _initial_nnx_value_from_config = _config.get(_NNX_ENABLED_KEY, False) + _initial_nnx_value_from_config = _config.get("nnx_enabled", False) if not isinstance(_initial_nnx_value_from_config, bool): _initial_nnx_value_from_config = ( str(_initial_nnx_value_from_config).lower() == "true" @@ -427,7 +424,7 @@ def max_steps_per_epoch(): "epsilon": epsilon(), "backend": _BACKEND, # Use the final _BACKEND value "image_data_format": image_data_format(), - _NNX_ENABLED_KEY: _current_nnx_status_for_saving, + "nnx_enabled": _current_nnx_status_for_saving, } try: with open(_config_path, "w") as f: diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index c0b80d826c96..4d8171f09fba 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -3,9 +3,9 @@ import jax.numpy as jnp import ml_dtypes import numpy as np -from flax import nnx from keras.src import tree +from keras.src.backend import config from keras.src.backend.common import KerasVariable from keras.src.backend.common import global_state from keras.src.backend.common import standardize_dtype @@ -17,6 +17,14 @@ from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib +if config.is_nnx_backend_enabled(): + try: + from flax import nnx + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." 
+ "Try: `pip install flax`" + ) SUPPORTS_SPARSE_TENSORS = True SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index 8d1cd6242bb1..d5135f2c8987 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,4 +1,13 @@ -from flax import nnx +from keras.src.backend import config + +if config.is_nnx_backend_enabled(): + try: + from flax import nnx + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." + "Try: `pip install flax`" + ) class JaxLayer: diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 759bdbbdcf68..cac1af79c700 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -5,7 +5,6 @@ import jax import numpy as np -from flax import nnx from keras.src import backend from keras.src import callbacks as callbacks_module @@ -21,6 +20,15 @@ from keras.src.trainers.epoch_iterator import EpochIterator from keras.src.utils import traceback_utils +if config.is_nnx_backend_enabled(): + try: + from flax import nnx + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." + "Try: `pip install flax`" + ) + class JAXTrainer(base_trainer.Trainer): def __init__(self): diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index de4786a7c992..30e58ca5a3cb 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -6,6 +6,7 @@ from keras.src import tree from keras.src.api_export import keras_export from keras.src.backend.common.keras_tensor import any_symbolic_tensors +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.ops.node import Node from keras.src.utils import python_utils from keras.src.utils import traceback_utils @@ -123,9 +124,17 @@ def __new__(cls, *args, **kwargs): """ instance = super(Operation, cls).__new__(cls) if backend.backend() == "jax": - from flax import nnx + if is_nnx_backend_enabled(): + try: + from flax import nnx + + vars(instance)["_object__state"] = nnx.object.ObjectState() + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`. " + "Please install it via `pip install flax`." + ) - vars(instance)["_object__state"] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`. arg_names = inspect.getfullargspec(cls.__init__).args kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index d2358c0c0cde..1e76a9dfe70c 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,6 +9,5 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax>=0.10.1 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index c76360c7a501..34dd67e1d0f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax>=0.10.1 # Common deps. 
-r requirements-common.txt From 561f70abed2296e6509f62aa9293483a80f0e5d4 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 18:37:18 +0000 Subject: [PATCH 037/103] update conditional imports --- keras/src/backend/config.py | 29 +-- keras/src/backend/jax/core.py | 448 +++++++++++++++++---------------- keras/src/backend/jax/layer.py | 18 +- 3 files changed, 254 insertions(+), 241 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 31268313ea42..450eb3ed331d 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -15,6 +15,9 @@ # Default backend: TensorFlow. _BACKEND = "tensorflow" +# Whether NNX is enabled. +_NNX_ENABLED = False + # Cap run duration for debugging. _MAX_EPOCHS = None _MAX_STEPS_PER_EPOCH = None @@ -238,14 +241,14 @@ def is_nnx_backend_enabled(): bool: `True` if NNX backend features are enabled, `False` otherwise. Defaults to `False`. """ - from keras.src.backend.common import global_state - - return global_state.get_global_attribute("nnx_enabled", default=False) + return _NNX_ENABLED def set_nnx_backend_enabled(value: bool): + global _NNX_ENABLED from keras.src.backend.common import global_state + _NNX_ENABLED = bool(value) global_state.set_global_attribute("nnx_enabled", bool(value)) @@ -280,7 +283,6 @@ def keras_home(): # Attempt to read Keras config file. _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json")) -_initial_nnx_value_from_config = False # Safe default before reading config if os.path.exists(_config_path): try: @@ -295,11 +297,11 @@ def keras_home(): _backend = _config.get("backend", _BACKEND) _image_data_format = _config.get("image_data_format", image_data_format()) assert _image_data_format in {"channels_last", "channels_first"} - _initial_nnx_value_from_config = _config.get("nnx_enabled", False) - if not isinstance(_initial_nnx_value_from_config, bool): - _initial_nnx_value_from_config = ( - str(_initial_nnx_value_from_config).lower() == "true" - ) + _nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED) + if not isinstance(_nnx_enabled_config, bool): + _NNX_ENABLED = str(_nnx_enabled_config).lower() == "true" + else: + _NNX_ENABLED = _nnx_enabled_config # Apply basic configs that don't cause circular import set_floatx(_floatx) @@ -418,13 +420,12 @@ def max_steps_per_epoch(): pass if not os.path.exists(_config_path): - _current_nnx_status_for_saving = is_nnx_backend_enabled() _config_to_save = { "floatx": floatx(), "epsilon": epsilon(), "backend": _BACKEND, # Use the final _BACKEND value "image_data_format": image_data_format(), - "nnx_enabled": _current_nnx_status_for_saving, + "nnx_enabled": _NNX_ENABLED, } try: with open(_config_path, "w") as f: @@ -436,8 +437,8 @@ def max_steps_per_epoch(): if "KERAS_NNX_ENABLED" in os.environ: env_val = os.environ["KERAS_NNX_ENABLED"].lower() if env_val == "true": - _initial_nnx_value_from_config = True + _NNX_ENABLED = True elif env_val == "false": - _initial_nnx_value_from_config = False + _NNX_ENABLED = False -set_nnx_backend_enabled(_initial_nnx_value_from_config) +set_nnx_backend_enabled(_NNX_ENABLED) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 4d8171f09fba..755b314a672b 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -17,14 +17,6 @@ from keras.src.backend.common.symbolic_scope import SymbolicScope from keras.src.backend.jax import distribution_lib -if config.is_nnx_backend_enabled(): - try: - from flax import nnx - except ImportError: - 
raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) SUPPORTS_SPARSE_TENSORS = True SUPPORTS_RAGGED_TENSORS = False IS_THREAD_SAFE = True @@ -66,226 +58,246 @@ def __jax_array__(self): return self.value -class NnxVariable(KerasVariable, nnx.Variable): - def __init__( - self, - initializer, - shape=None, - dtype=None, - trainable=True, - autocast=True, - aggregation="none", - synchronization="auto", - name=None, - layout=None, - mutable=None, - **nnx_metadata, - ): - # Determine NNX mutability. This needs to be known for - # nnx.Variable.__init__. - if mutable is None: - actual_nnx_mutable = ( - trainable # Keras 'trainable' maps to NNX 'mutable' - ) - else: - actual_nnx_mutable = mutable - - # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' param - # takes precedence. - if "mutable" in nnx_metadata and mutable is not None: - nnx_metadata["mutable"] = actual_nnx_mutable - elif "mutable" not in nnx_metadata: - nnx_metadata["mutable"] = actual_nnx_mutable - - # Initialize nnx.Variable first. - if shape is not None and dtype is not None: - # If initializer is a Keras callable, it's not ready yet. - # If initializer is already a value, KerasVariable will handle it. - # We need a concrete array for the placeholder. - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(dtype) - ) - elif shape is not None: - _placeholder_value = jnp.zeros(shape, dtype=jnp.float32) - else: - _placeholder_value = jnp.array(0.0, dtype=jnp.float32) - - # Call nnx.Variable.__init__ directly. - nnx.Variable.__init__(self, value=_placeholder_value, **nnx_metadata) - - # Store JAX-specific layout using object.__setattr__ BEFORE - # KerasVariable init. - # This is because KerasVariable.__init__ will call self._initialize, - # which uses self._layout. - object.__setattr__(self, "_layout", layout) - - # Initialize KerasVariable. - super(NnxVariable, self).__init__( - initializer=initializer, - shape=shape, - dtype=dtype, - trainable=trainable, # Pass Keras trainable - autocast=autocast, - aggregation=aggregation, - synchronization=synchronization, - name=name, - ) - - @property - def _value(self): - if hasattr(self, "raw_value"): - return self.raw_value - return None - - @_value.setter - def _value(self, new_keras_value): - self._direct_assign(new_keras_value) - - def __getstate__(self): - # Get the state from KerasVariable (attributes in __dict__) - # KerasVariable does not have a custom __getstate__, so we mimic - # default behavior. - keras_state = self.__dict__.copy() - - # Get the state from nnx.Variable - nnx_specific_state = super(KerasVariable, self).__getstate__() - - # Merge them. Keras state is primary. NNX specific state adds to it. - if "raw_value" in nnx_specific_state: - keras_state["_value"] = nnx_specific_state["raw_value"] - - # Add NNX attributes that are not in Keras's __dict__ - if "_trace_state" in nnx_specific_state: - keras_state["_trace_state"] = nnx_specific_state["_trace_state"] - if "_var_metadata" in nnx_specific_state: - keras_state["_var_metadata"] = nnx_specific_state["_var_metadata"] - - # Remove elements that might be problematic or redundant if - # nnx.Variable's __getstate__ - keras_state.pop("raw_value", None) - - return keras_state - - def __setstate__(self, state): - # Separate nnx specific keys that we added if they are not part of - # Keras __dict__ - # Our __getstate__ puts them into the main state dictionary. 
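A subtle point in the `__getstate__` being re-indented in this hunk: `super(KerasVariable, self).__getstate__()` starts the MRO lookup after `KerasVariable`, which is how the call lands on `nnx.Variable`'s implementation. The mechanism in isolation (class names hypothetical):

class KerasSide:
    def __getstate__(self):
        return dict(self.__dict__, source="keras")


class NnxSide:
    def __getstate__(self):
        return {"source": "nnx"}


class Var(KerasSide, NnxSide):
    def __getstate__(self):
        state = self.__dict__.copy()
        # Skips KerasSide in the MRO and resolves to NnxSide.__getstate__,
        # mirroring super(KerasVariable, self).__getstate__() in the diff.
        state.update(super(KerasSide, self).__getstate__())
        return state


assert Var().__getstate__()["source"] == "nnx"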
- nnx_raw_value = state["_value"] # This was raw_value - nnx_trace_state = state.pop("_trace_state", None) - nnx_var_metadata = state.pop("_var_metadata", None) - - # Populate the instance's __dict__ with the Keras attributes. - self.__dict__.update(state) - - # restore the nnx.Variable specific slotted attributes. - object.__setattr__(self, "raw_value", nnx_raw_value) - - if nnx_trace_state is not None: - object.__setattr__(self, "_trace_state", nnx_trace_state) - else: - pass - - if nnx_var_metadata is not None: - object.__setattr__(self, "_var_metadata", nnx_var_metadata) - else: - pass - - # Ensure Keras's self._value is also consistent with the restored - # raw_value - object.__setattr__(self, "_value", nnx_raw_value) - - if hasattr(self, "_shape") and self._shape is not None: - self._ndim = len(self._shape) - else: - # Fallback if shape isn't immediately available. - self._ndim = len(self.raw_value.shape) +_JAX_VARIABLE_TYPES = (JaxVariable,) +if config.is_nnx_backend_enabled(): + try: + from flax import nnx - def _initialize(self, value): - # Note that variable.shape is needed by distribution_lib - self._shape = self._validate_shape(value.shape) - # We can't import the keras/distribution/distribution_lib - # due to circular dependency. - distribution = global_state.get_global_attribute("distribution") - if self._layout is None and distribution is not None: - tensor_layout = distribution.get_variable_layout(self) - from keras.src.distribution import TensorLayout + class NnxVariable(KerasVariable, nnx.Variable): + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + layout=None, + mutable=None, + **nnx_metadata, + ): + # Determine NNX mutability. This needs to be known for + # nnx.Variable.__init__. + if mutable is None: + actual_nnx_mutable = ( + trainable # Keras 'trainable' maps to NNX 'mutable' + ) + else: + actual_nnx_mutable = mutable + + # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' + # param takes precedence. + if "mutable" in nnx_metadata and mutable is not None: + nnx_metadata["mutable"] = actual_nnx_mutable + elif "mutable" not in nnx_metadata: + nnx_metadata["mutable"] = actual_nnx_mutable + + # Initialize nnx.Variable first. + if shape is not None and dtype is not None: + # If initializer is a Keras callable, it's not ready yet. + # If initializer is already a value, KerasVariable will + # handle it. We need a concrete array for the placeholder. + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(dtype) + ) + elif shape is not None: + _placeholder_value = jnp.zeros(shape, dtype=jnp.float32) + else: + _placeholder_value = jnp.array(0.0, dtype=jnp.float32) - if isinstance(tensor_layout, TensorLayout): - self._layout = tensor_layout.backend_layout - else: - self._layout = tensor_layout - self._direct_assign(value) + # Call nnx.Variable.__init__ directly. 
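Calling `nnx.Variable.__init__` explicitly with a placeholder, and only afterwards letting `super().__init__` run `KerasVariable`'s setup (which reads `_layout` and eventually assigns the real value), is the usual two-phase pattern for a non-cooperative second base class; the call itself continues on the next hunk line. Reduced to its skeleton (names hypothetical):

class Primary:  # stands in for KerasVariable
    def __init__(self, name):
        self.name = name


class Backing:  # stands in for nnx.Variable
    def __init__(self, value):
        self.value = value


class Combined(Primary, Backing):
    def __init__(self, name):
        # Seed the backing store with a concrete placeholder first; the
        # real value is assigned later during Primary's initialization.
        Backing.__init__(self, value=0.0)
        super().__init__(name=name)  # resolves to Primary.__init__


c = Combined("w")
assert c.value == 0.0 and c.name == "w"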
+ nnx.Variable.__init__( + self, value=_placeholder_value, **nnx_metadata + ) - def _direct_assign(self, value): - # Apply JAX-specific distribution if layout is present - if self._layout is not None: - processed_value = distribution_lib.distribute_variable( - value, self._layout - ) - else: - processed_value = value - - # Ensure that nnx.Variable part is initialized - if not hasattr(self, "_var_metadata"): - # todo: should add a warning - pass - - # Apply on_set_value hook if it exists - if ( - hasattr(self, "_var_metadata") - and "on_set_value" in self._var_metadata - ): - final_value = self._var_metadata["on_set_value"]( - self, processed_value - ) - else: - final_value = processed_value + # Store JAX-specific layout using object.__setattr__ BEFORE + # KerasVariable init. + # This is because KerasVariable.__init__ will call + # self._initialize, which uses self._layout. + object.__setattr__(self, "_layout", layout) + + # Initialize KerasVariable. + super(NnxVariable, self).__init__( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + ) - # Directly set raw_value. nnx.Variable handles mutable array updates - object.__setattr__(self, "raw_value", final_value) + @property + def _value(self): + if hasattr(self, "raw_value"): + return self.raw_value + return None + + @_value.setter + def _value(self, new_keras_value): + self._direct_assign(new_keras_value) + + def __getstate__(self): + # Get the state from KerasVariable (attributes in __dict__) + # KerasVariable does not have a custom __getstate__, so we mimic + # default behavior. + keras_state = self.__dict__.copy() + + # Get the state from nnx.Variable + nnx_specific_state = super(KerasVariable, self).__getstate__() + + # Merge them. Keras state is primary. NNX specific state adds + # to it. + if "raw_value" in nnx_specific_state: + keras_state["_value"] = nnx_specific_state["raw_value"] + + # Add NNX attributes that are not in Keras's __dict__ + if "_trace_state" in nnx_specific_state: + keras_state["_trace_state"] = nnx_specific_state[ + "_trace_state" + ] + if "_var_metadata" in nnx_specific_state: + keras_state["_var_metadata"] = nnx_specific_state[ + "_var_metadata" + ] + + # Remove elements that might be problematic or redundant if + # nnx.Variable's __getstate__ + keras_state.pop("raw_value", None) + + return keras_state + + def __setstate__(self, state): + # Separate nnx specific keys that we added if they are not part + # of Keras __dict__ this __getstate__ puts them into the main + # state dictionary. + nnx_raw_value = state["_value"] # This was raw_value + nnx_trace_state = state.pop("_trace_state", None) + nnx_var_metadata = state.pop("_var_metadata", None) + + # Populate the instance's __dict__ with the Keras attributes. + self.__dict__.update(state) + + # restore the nnx.Variable specific slotted attributes. + object.__setattr__(self, "raw_value", nnx_raw_value) + + if nnx_trace_state is not None: + object.__setattr__(self, "_trace_state", nnx_trace_state) + else: + pass - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype, sparse=False) + if nnx_var_metadata is not None: + object.__setattr__(self, "_var_metadata", nnx_var_metadata) + else: + pass - # Overload native accessor. 
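`__jax_array__`, overloaded on the next line of this hunk, is the (semi-public) protocol jax.numpy consults when coercing foreign objects to arrays; it is what lets a Keras variable be passed directly into JAX ops. A minimal demonstration, assuming the protocol behaves as in current JAX releases (wrapper class hypothetical):

import jax.numpy as jnp


class Box:
    # Tiny stand-in for a variable wrapping a JAX array.
    def __init__(self, value):
        self._value = value

    def __jax_array__(self):
        # jax.numpy calls this while converting the argument to an array.
        return self._value


total = jnp.sum(Box(jnp.ones((2, 3))))  # behaves as if the raw array were passed
assert float(total) == 6.0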
- def __jax_array__(self): - return self.value + # Ensure Keras's self._value is also consistent with the + # restored raw_value + object.__setattr__(self, "_value", nnx_raw_value) - @property - def value(self): - if not hasattr(self, "raw_value"): - if not hasattr(self, "_value") or self._value is None: - if self._initializer is not None: - initial_value = self._initializer( - self._shape, dtype=self._dtype - ) - return self._maybe_autocast(initial_value) + if hasattr(self, "_shape") and self._shape is not None: + self._ndim = len(self._shape) else: - raise AttributeError( - "Variable is not properly initialized and has no " - "initializer." + # Fallback if shape isn't immediately available. + self._ndim = len(self.raw_value.shape) + + def _initialize(self, value): + # Note that variable.shape is needed by distribution_lib + self._shape = self._validate_shape(value.shape) + # We can't import the keras/distribution/distribution_lib + # due to circular dependency. + distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + tensor_layout = distribution.get_variable_layout(self) + from keras.src.distribution import TensorLayout + + if isinstance(tensor_layout, TensorLayout): + self._layout = tensor_layout.backend_layout + else: + self._layout = tensor_layout + self._direct_assign(value) + + def _direct_assign(self, value): + # Apply JAX-specific distribution if layout is present + if self._layout is not None: + processed_value = distribution_lib.distribute_variable( + value, self._layout ) - current_value = self._value - else: - current_value = self.raw_value - if ( - hasattr(self, "_var_metadata") - and "on_get_value" in self._var_metadata - ): - current_value = self._var_metadata["on_get_value"]( - self, current_value - ) - - if in_stateless_scope(): - scope = get_stateless_scope() - stateless_value = scope.get_current_value(self) - if stateless_value is not None: - return self._maybe_autocast(stateless_value) + else: + processed_value = value - return self._maybe_autocast(current_value) + # Ensure that nnx.Variable part is initialized + if not hasattr(self, "_var_metadata"): + # todo: should add a warning + pass - def __hash__(self): - return id(self) + # Apply on_set_value hook if it exists + if ( + hasattr(self, "_var_metadata") + and "on_set_value" in self._var_metadata + ): + final_value = self._var_metadata["on_set_value"]( + self, processed_value + ) + else: + final_value = processed_value + + # Directly set raw_value. nnx.Variable handles mutable array + # updates + object.__setattr__(self, "raw_value", final_value) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype, sparse=False) + + # Overload native accessor. + def __jax_array__(self): + return self.value + + @property + def value(self): + if not hasattr(self, "raw_value"): + if not hasattr(self, "_value") or self._value is None: + if self._initializer is not None: + initial_value = self._initializer( + self._shape, dtype=self._dtype + ) + return self._maybe_autocast(initial_value) + else: + raise AttributeError( + "Variable is not properly initialized and has" + " no initializer." 
+ ) + current_value = self._value + else: + current_value = self.raw_value + if ( + hasattr(self, "_var_metadata") + and "on_get_value" in self._var_metadata + ): + current_value = self._var_metadata["on_get_value"]( + self, current_value + ) + + if in_stateless_scope(): + scope = get_stateless_scope() + stateless_value = scope.get_current_value(self) + if stateless_value is not None: + return self._maybe_autocast(stateless_value) + + return self._maybe_autocast(current_value) + + def __hash__(self): + return id(self) + + _JAX_VARIABLE_TYPES += (NnxVariable,) + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." + "Try: `pip install flax`" + ) def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): @@ -301,7 +313,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): # an existing distributed jax array will raise error. return x - if isinstance(x, (JaxVariable, NnxVariable)): + if isinstance(x, _JAX_VARIABLE_TYPES): if dtype is not None and x.dtype != dtype: return x.value.astype(dtype) return x.value diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index d5135f2c8987..e1a8ab9e4ee7 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,19 +1,19 @@ from keras.src.backend import config + +class JaxLayer: + pass + + if config.is_nnx_backend_enabled(): try: from flax import nnx + + class NnxLayer(nnx.Module): + def __init_subclass__(cls): + super().__init_subclass__() except ImportError: raise ImportError( "To use the NNX backend, you must install `flax`." "Try: `pip install flax`" ) - - -class JaxLayer: - pass - - -class NnxLayer(nnx.Module): - def __init_subclass__(cls): - super().__init_subclass__() From e7caa03d48b0e387da2bb931b3533c56a0f1eeba Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 18:51:09 +0000 Subject: [PATCH 038/103] fix tests --- keras/src/backend/config.py | 3 +-- keras/src/backend/jax/core_test.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 450eb3ed331d..b0d5d2e2b299 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -308,8 +308,7 @@ def keras_home(): set_epsilon(_epsilon) set_image_data_format(_image_data_format) _BACKEND = _backend -else: - _config = {} + if "KERAS_BACKEND" in os.environ: _backend = os.environ["KERAS_BACKEND"] if _backend: diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index cf875335a3b8..d7e6f11def06 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -4,13 +4,22 @@ import jax.numpy as jnp import numpy as np import pytest -from flax import nnx import keras from keras.src import backend from keras.src import testing +from keras.src.backend import config from keras.src.backend.jax.core import NnxVariable +if config.is_nnx_backend_enabled(): + try: + from flax import nnx + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." 
+ "Try: `pip install flax`" + ) + @pytest.mark.skipif( backend.backend() != "jax", From 8e3f460ddeb12ce198bc959654aaa8551bc318e8 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 19:12:28 +0000 Subject: [PATCH 039/103] add github workflow for nnx --- .github/workflows/actions.yml | 6 +++++- .github/workflows/config/nnx/keras.json | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/config/nnx/keras.json diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index b9e785dfc949..f9c15cac9706 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: python-version: ['3.10'] - backend: [tensorflow, jax, torch, numpy, openvino] + backend: [tensorflow, jax, torch, numpy, openvino, nnx] name: Run tests runs-on: ubuntu-latest env: @@ -50,6 +50,10 @@ jobs: pip install -r requirements.txt --progress-bar off --upgrade pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade + - name: Install Flax for NNX backend + if: matrix.backend == 'nnx' + run: | + pip install flax --progress-bar off --upgrade - name: Test applications with pytest if: ${{ steps.filter.outputs.applications == 'true' }} run: | diff --git a/.github/workflows/config/nnx/keras.json b/.github/workflows/config/nnx/keras.json new file mode 100644 index 000000000000..d6bb3e7fd4d5 --- /dev/null +++ b/.github/workflows/config/nnx/keras.json @@ -0,0 +1,7 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "jax", + "image_data_format": "channels_last", + "nnx_enabled": true +} From d70d51c1f1261971b55d3e57305ee24caac8138c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 19:16:32 +0000 Subject: [PATCH 040/103] fix test --- keras/src/backend/jax/core_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index d7e6f11def06..936acd338d88 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -9,11 +9,12 @@ from keras.src import backend from keras.src import testing from keras.src.backend import config -from keras.src.backend.jax.core import NnxVariable if config.is_nnx_backend_enabled(): try: from flax import nnx + + from keras.src.backend.jax.core import NnxVariable except ImportError: raise ImportError( "To use the NNX backend, you must install `flax`." 
From 38dbd4bbde3df6f14b0cb3fc4c4bfaeb6f9a3d73 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 22:42:27 +0000 Subject: [PATCH 041/103] address comments --- guides/distributed_training_with_jax.py | 30 +++-------------------- keras/src/backend/config.py | 12 ++++++++-- keras/src/backend/jax/core.py | 8 +++---- keras/src/backend/jax/core_test.py | 14 ++--------- keras/src/backend/jax/layer.py | 17 ++++--------- keras/src/backend/jax/trainer.py | 32 +++++-------------------- keras/src/layers/layer.py | 13 ++++++++++ keras/src/ops/operation.py | 12 ---------- keras/src/utils/jax_utils.py | 14 +++++++++++ requirements-common.txt | 1 - requirements-jax-cuda.txt | 2 +- requirements.txt | 2 +- 12 files changed, 58 insertions(+), 99 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index bfbe94af452b..9ce5cf42f948 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -54,16 +54,8 @@ from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P from keras.src.backend.config import is_nnx_backend_enabled -from keras.src.backend import config - -if config.is_nnx_backend_enabled(): - try: - from flax import nnx - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) +from keras.src.utils.jax_utils import jit +from flax import nnx def get_model(): @@ -196,24 +188,8 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): compute_gradients = jax.value_and_grad(compute_loss, has_aux=True) -def conditional_jit(condition, *args, **kwargs): - """ - Applies @jax.jit or @nnx.jit based on a condition. - """ - - def decorator(func): - if condition: - print("Using @nnx.jit") - return nnx.jit(func, *args, **kwargs) - else: - print("Using @jax.jit") - return jax.jit(func, *args, **kwargs) - - return decorator - - # Training step, Keras provides a pure functional optimizer.stateless_apply -@conditional_jit(is_nnx_backend_enabled()) +@jit() def train_step(train_state, x, y): ( trainable_variables, diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index b0d5d2e2b299..f3ee55ae077c 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -244,11 +244,19 @@ def is_nnx_backend_enabled(): return _NNX_ENABLED -def set_nnx_backend_enabled(value: bool): +def set_nnx_enabled(value: bool): global _NNX_ENABLED from keras.src.backend.common import global_state _NNX_ENABLED = bool(value) + if _NNX_ENABLED: + try: + from flax import nnx # noqa F401 + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." 
+ "Try: `pip install flax`" + ) global_state.set_global_attribute("nnx_enabled", bool(value)) @@ -440,4 +448,4 @@ def max_steps_per_epoch(): elif env_val == "false": _NNX_ENABLED = False -set_nnx_backend_enabled(_NNX_ENABLED) +set_nnx_enabled(_NNX_ENABLED) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 755b314a672b..a6891af289c8 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -58,7 +58,7 @@ def __jax_array__(self): return self.value -_JAX_VARIABLE_TYPES = (JaxVariable,) +_JAX_VARIABLE_TYPE = JaxVariable if config.is_nnx_backend_enabled(): try: from flax import nnx @@ -292,7 +292,7 @@ def value(self): def __hash__(self): return id(self) - _JAX_VARIABLE_TYPES += (NnxVariable,) + _JAX_VARIABLE_TYPE = NnxVariable except ImportError: raise ImportError( "To use the NNX backend, you must install `flax`." @@ -313,7 +313,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): # an existing distributed jax array will raise error. return x - if isinstance(x, _JAX_VARIABLE_TYPES): + if isinstance(x, _JAX_VARIABLE_TYPE): if dtype is not None and x.dtype != dtype: return x.value.astype(dtype) return x.value @@ -597,7 +597,7 @@ def fori_loop(lower, upper, body_fun, init_val): def stop_gradient(variable): - if isinstance(variable, (JaxVariable, NnxVariable)): + if isinstance(variable, _JAX_VARIABLE_TYPE): variable = variable.value return jax.lax.stop_gradient(variable) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index 936acd338d88..cf875335a3b8 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -4,22 +4,12 @@ import jax.numpy as jnp import numpy as np import pytest +from flax import nnx import keras from keras.src import backend from keras.src import testing -from keras.src.backend import config - -if config.is_nnx_backend_enabled(): - try: - from flax import nnx - - from keras.src.backend.jax.core import NnxVariable - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) +from keras.src.backend.jax.core import NnxVariable @pytest.mark.skipif( diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index e1a8ab9e4ee7..8d1cd6242bb1 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,19 +1,10 @@ -from keras.src.backend import config +from flax import nnx class JaxLayer: pass -if config.is_nnx_backend_enabled(): - try: - from flax import nnx - - class NnxLayer(nnx.Module): - def __init_subclass__(cls): - super().__init_subclass__() - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." 
- "Try: `pip install flax`" - ) +class NnxLayer(nnx.Module): + def __init_subclass__(cls): + super().__init_subclass__() diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index cac1af79c700..8e60ebfa167a 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -12,22 +12,13 @@ from keras.src import tree from keras.src.backend import config from keras.src.backend import distribution_lib as jax_distribution_lib -from keras.src.backend.config import is_nnx_backend_enabled from keras.src.distribution import distribution_lib from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import array_slicing from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.epoch_iterator import EpochIterator from keras.src.utils import traceback_utils - -if config.is_nnx_backend_enabled(): - try: - from flax import nnx - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) +from keras.src.utils.jax_utils import jit class JAXTrainer(base_trainer.Trainer): @@ -243,10 +234,7 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - if is_nnx_backend_enabled(): - concatenate = nnx.jit(concatenate) - else: - concatenate = jax.jit(concatenate) + concatenate = jit(concatenate) def iterator_step(state, iterator): data = next(iterator) @@ -290,10 +278,7 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - if is_nnx_backend_enabled(): - train_step = nnx.jit(self.train_step, donate_argnums=0) - else: - train_step = jax.jit(self.train_step, donate_argnums=0) + train_step = jit(self.train_step, donate_argnums=0) else: train_step = self.train_step @@ -309,10 +294,8 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - if is_nnx_backend_enabled(): - test_step = nnx.jit(self.test_step, donate_argnums=0) - else: - test_step = jax.jit(self.test_step, donate_argnums=0) + test_step = jit(self.test_step, donate_argnums=0) + else: test_step = self.test_step @@ -329,10 +312,7 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - if is_nnx_backend_enabled(): - predict_step = nnx.jit(predict_step, donate_argnums=0) - else: - predict_step = jax.jit(predict_step, donate_argnums=0) + predict_step = jit(predict_step, donate_argnums=0) _step_function = self._make_function( predict_step, concatenate_outputs=True diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 1eff520ba8a2..a5b3b85e8986 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -222,9 +222,22 @@ def call(self, inputs): ``` """ + def __init_subclass__(cls): + super().__init_subclass__() + def __new__(cls, *args, **kwargs): obj = super().__new__(cls, *args, **kwargs) + instance = super(Layer, cls).__new__(cls) + if backend.backend() == "jax" and is_nnx_backend_enabled(): + try: + from flax import nnx + vars(instance)["_object__state"] = nnx.object.ObjectState() + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`. " + "Please install it via `pip install flax`." 
+ ) # Wrap the user-provided `build` method in the `build_wrapper` # to add name scope support and serialization support. original_build_method = obj.build diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index 30e58ca5a3cb..46eb4b25f533 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -6,7 +6,6 @@ from keras.src import tree from keras.src.api_export import keras_export from keras.src.backend.common.keras_tensor import any_symbolic_tensors -from keras.src.backend.config import is_nnx_backend_enabled from keras.src.ops.node import Node from keras.src.utils import python_utils from keras.src.utils import traceback_utils @@ -123,17 +122,6 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. """ instance = super(Operation, cls).__new__(cls) - if backend.backend() == "jax": - if is_nnx_backend_enabled(): - try: - from flax import nnx - - vars(instance)["_object__state"] = nnx.object.ObjectState() - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`. " - "Please install it via `pip install flax`." - ) # Generate a config to be returned by default by `get_config()`. arg_names = inspect.getfullargspec(cls.__init__).args diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index d5375785f762..8ea7a1caf577 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -1,4 +1,8 @@ +import jax +from flax import nnx + from keras.src import backend +from keras.src.backend.config import is_nnx_backend_enabled def is_in_jax_tracing_scope(x=None): @@ -9,3 +13,13 @@ def is_in_jax_tracing_scope(x=None): if c.__name__ == "Tracer" and c.__module__.startswith("jax"): return True return False + + +def jit(*args, **kwargs): + def decorator(func): + if is_nnx_backend_enabled(): + return nnx.jit(func, *args, **kwargs) + else: + return jax.jit(func, *args, **kwargs) + + return decorator diff --git a/requirements-common.txt b/requirements-common.txt index 21ec0efe7cdd..7edc40c97a1a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,4 +24,3 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino -flax>=0.10.1 diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 1e76a9dfe70c..765263e82696 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,5 +9,5 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 - +flax>=0.10.1 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index 34dd67e1d0f6..730f1fb2601c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 - +flax>=0.10.1 # Common deps. 
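The `jax_utils.jit` helper added above is a decorator factory: it returns the inner `decorator`, so call sites must invoke it twice. `jit(self.train_step, donate_argnums=0)` therefore returned the undecorated closure instead of a compiled step (the very next patch fixes the call sites), while `donate_argnums=0` is what lets XLA reuse the input state buffers and halve training-state memory. The failure mode in miniature (stand-in decorator, no real jitting):

def jit(*args, **kwargs):
    def decorator(func):
        func.jitted = True  # stand-in for returning nnx.jit/jax.jit(func)
        return func
    return decorator


def step(state):
    return state


broken = jit(step, donate_argnums=0)  # returns `decorator`; step untouched
assert callable(broken) and not hasattr(step, "jitted")

fixed = jit(donate_argnums=0)(step)  # factory first, then decorate
assert fixed.jitted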
-r requirements-common.txt From 1c60c5ef86c5d3c030adf79da0ef1b6f3fde9f30 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 22:56:57 +0000 Subject: [PATCH 042/103] fix test --- keras/src/backend/jax/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 8e60ebfa167a..975564fd0909 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -234,7 +234,7 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - concatenate = jit(concatenate) + concatenate = jit()(concatenate) def iterator_step(state, iterator): data = next(iterator) @@ -278,7 +278,7 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - train_step = jit(self.train_step, donate_argnums=0) + train_step = jit(donate_argnums=0)(self.train_step) else: train_step = self.train_step @@ -294,7 +294,7 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - test_step = jit(self.test_step, donate_argnums=0) + test_step = jit(donate_argnums=0)(self.test_step) else: test_step = self.test_step @@ -312,7 +312,7 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = jit(predict_step, donate_argnums=0) + predict_step = jit(donate_argnums=0)(predict_step) _step_function = self._make_function( predict_step, concatenate_outputs=True From 74835fd4663af98072a524fd29966af30d4a25c4 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 23:31:26 +0000 Subject: [PATCH 043/103] address comments --- examples/demo_jax_distributed.py | 1 + keras/src/backend/jax/core.py | 29 +++++++++++++++++++---------- keras/src/utils/jax_utils.py | 7 ++++--- 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/examples/demo_jax_distributed.py b/examples/demo_jax_distributed.py index b54b3cc2f1db..906dc47563de 100644 --- a/examples/demo_jax_distributed.py +++ b/examples/demo_jax_distributed.py @@ -12,6 +12,7 @@ import jax.numpy as jnp import tensorflow as tf # just for tf.data import keras # Keras multi-backend + import numpy as np from tqdm import tqdm diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index a6891af289c8..c72436c5f19b 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -95,17 +95,26 @@ def __init__( nnx_metadata["mutable"] = actual_nnx_mutable # Initialize nnx.Variable first. - if shape is not None and dtype is not None: - # If initializer is a Keras callable, it's not ready yet. - # If initializer is already a value, KerasVariable will - # handle it. We need a concrete array for the placeholder. - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(dtype) - ) - elif shape is not None: - _placeholder_value = jnp.zeros(shape, dtype=jnp.float32) + # Determine the dtype for the placeholder. 
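The placeholder construction continued in the hunk below branches four ways, but the effect is compact: a missing shape yields a scalar placeholder, and a missing dtype falls back to `config.floatx()` (default "float32"). A condensed sketch of the same logic, not the committed code (the real code also routes through `standardize_dtype`):

import jax.numpy as jnp


def make_placeholder(shape=None, dtype=None, floatx="float32"):
    # shape=None -> scalar; dtype=None -> Keras floatx default
    return jnp.zeros(shape if shape is not None else (), dtype=dtype or floatx)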
+ _placeholder_value = None + if shape is not None: + if dtype is not None: + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(dtype) + ) + else: + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(config.floatx()) + ) else: - _placeholder_value = jnp.array(0.0, dtype=jnp.float32) + if dtype is not None: + _placeholder_value = jnp.array( + 0.0, dtype=standardize_dtype(dtype) + ) + else: + _placeholder_value = jnp.array( + 0.0, dtype=standardize_dtype(config.floatx()) + ) # Call nnx.Variable.__init__ directly. nnx.Variable.__init__( diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index 8ea7a1caf577..56b164d18125 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -1,6 +1,3 @@ -import jax -from flax import nnx - from keras.src import backend from keras.src.backend.config import is_nnx_backend_enabled @@ -16,8 +13,12 @@ def is_in_jax_tracing_scope(x=None): def jit(*args, **kwargs): + import jax + def decorator(func): if is_nnx_backend_enabled(): + from flax import nnx + return nnx.jit(func, *args, **kwargs) else: return jax.jit(func, *args, **kwargs) From 6f11c0c897d05035f18f9672317786b3a00e038a Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 6 Jun 2025 23:41:00 +0000 Subject: [PATCH 044/103] fix test --- keras/src/backend/jax/core_test.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index cf875335a3b8..b961f574992c 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -4,20 +4,23 @@ import jax.numpy as jnp import numpy as np import pytest -from flax import nnx import keras from keras.src import backend from keras.src import testing +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.backend.jax.core import NnxVariable +if is_nnx_backend_enabled(): + from flax import nnx + @pytest.mark.skipif( backend.backend() != "jax", reason="JAX backend specific test for core Variable integration with NNX.", ) @pytest.mark.skipif( - not keras.config.is_nnx_backend_enabled(), + not is_nnx_backend_enabled(), reason="Test requires NNX backend to be enabled by default for setup.", ) class JaxCoreVariableTest(testing.TestCase): From c05166ea6cbe672735e198dc395f185ee924145d Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Sat, 7 Jun 2025 00:08:32 +0000 Subject: [PATCH 045/103] fix test -_- --- keras/src/backend/jax/core_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index b961f574992c..77c1727c788e 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -9,11 +9,12 @@ from keras.src import backend from keras.src import testing from keras.src.backend.config import is_nnx_backend_enabled -from keras.src.backend.jax.core import NnxVariable if is_nnx_backend_enabled(): from flax import nnx + from keras.src.backend.jax.core import NnxVariable + @pytest.mark.skipif( backend.backend() != "jax", From 8582c7ec2e95b91dc85b06836ab25bedfc042c06 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 11 Jun 2025 23:19:32 +0000 Subject: [PATCH 046/103] put the set attr in operation --- keras/src/layers/layer.py | 11 ----------- keras/src/ops/operation.py | 12 +++++++++++- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 
a5b3b85e8986..926eca6e4435 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -227,17 +227,6 @@ def __init_subclass__(cls): def __new__(cls, *args, **kwargs): obj = super().__new__(cls, *args, **kwargs) - instance = super(Layer, cls).__new__(cls) - if backend.backend() == "jax" and is_nnx_backend_enabled(): - try: - from flax import nnx - - vars(instance)["_object__state"] = nnx.object.ObjectState() - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`. " - "Please install it via `pip install flax`." - ) # Wrap the user-provided `build` method in the `build_wrapper` # to add name scope support and serialization support. original_build_method = obj.build diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index 46eb4b25f533..e8a322ac5a82 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -6,6 +6,7 @@ from keras.src import tree from keras.src.api_export import keras_export from keras.src.backend.common.keras_tensor import any_symbolic_tensors +from keras.src.backend.config import is_nnx_backend_enabled from keras.src.ops.node import Node from keras.src.utils import python_utils from keras.src.utils import traceback_utils @@ -122,7 +123,16 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. """ instance = super(Operation, cls).__new__(cls) - + if backend.backend() == "jax" and is_nnx_backend_enabled(): + try: + from flax import nnx + + vars(instance)["_object__state"] = nnx.object.ObjectState() + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`. " + "Please install it via `pip install flax`." + ) # Generate a config to be returned by default by `get_config()`. arg_names = inspect.getfullargspec(cls.__init__).args kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) From 297775a792ff689b2a834c0216ae5a18fbc2f8e1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 11 Jun 2025 23:51:09 +0000 Subject: [PATCH 047/103] fix jax error --- keras/src/layers/layer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index e527f6a78713..24711f25ba05 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -54,8 +54,16 @@ if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": + if is_nnx_backend_enabled(): - from keras.src.backend.jax.layer import NnxLayer as BackendLayer + try: + from flax import nnx # noqa F401 + from keras.src.backend.jax.layer import NnxLayer as BackendLayer + except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." 
+ "Try: `pip install flax`" + ) else: from keras.src.backend.jax.layer import JaxLayer as BackendLayer elif backend.backend() == "torch": From f01cc0dc6f8f7fa46e1bfa774e1cea19b28215e3 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:06:31 +0000 Subject: [PATCH 048/103] fix trace error --- keras/src/backend/jax/layer.py | 8 +++++++- keras/src/layers/layer.py | 15 +++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index 8d1cd6242bb1..e5fcfdbebe34 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,4 +1,10 @@ -from flax import nnx +try: + from flax import nnx # noqa F401 +except ImportError: + raise ImportError( + "To use the NNX backend, you must install `flax`." + "Try: `pip install flax`" + ) class JaxLayer: diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 24711f25ba05..7236485a6b1f 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -54,7 +54,6 @@ if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": - if is_nnx_backend_enabled(): try: from flax import nnx # noqa F401 @@ -1547,7 +1546,19 @@ def __setattr__(self, name, value): if not hasattr(self, "_tracker"): self._initialize_tracker() value = self._tracker.track(value) - return super().__setattr__(name, value) + + # NNX-specific bypass for `_called` and `built` attributes + if ( + backend.backend() == "jax" + and is_nnx_backend_enabled() + and (name == "_called" or name == "built") + ): + object.__setattr__(self, name, value) + return + + super().__setattr__( + name, value + ) # Default path, including for NnxLayer -> nnx.Module def __delattr__(self, name): obj = getattr(self, name) From 68108484bde6cf5a1aa99e62a9e00e012eb96eca Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:18:25 +0000 Subject: [PATCH 049/103] remove installation --- .github/workflows/actions.yml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index f9c15cac9706..16bdb759a9c0 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -50,10 +50,6 @@ jobs: pip install -r requirements.txt --progress-bar off --upgrade pip uninstall -y keras keras-nightly pip install -e "." 
--progress-bar off --upgrade - - name: Install Flax for NNX backend - if: matrix.backend == 'nnx' - run: | - pip install flax --progress-bar off --upgrade - name: Test applications with pytest if: ${{ steps.filter.outputs.applications == 'true' }} run: | From dc793299b8b597cb3f6157d1dc89b3840d0e4b47 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:35:06 +0000 Subject: [PATCH 050/103] import fixes --- keras/src/backend/jax/core.py | 446 ++++++++++++++++----------------- keras/src/backend/jax/layer.py | 8 +- keras/src/layers/layer.py | 9 +- keras/src/ops/operation.py | 12 +- requirements-jax-cuda.txt | 2 +- requirements.txt | 2 +- 6 files changed, 226 insertions(+), 253 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index c72436c5f19b..e36d33e5b1f0 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -60,253 +60,245 @@ def __jax_array__(self): _JAX_VARIABLE_TYPE = JaxVariable if config.is_nnx_backend_enabled(): - try: - from flax import nnx - - class NnxVariable(KerasVariable, nnx.Variable): - def __init__( - self, - initializer, - shape=None, - dtype=None, - trainable=True, - autocast=True, - aggregation="none", - synchronization="auto", - name=None, - layout=None, - mutable=None, - **nnx_metadata, - ): - # Determine NNX mutability. This needs to be known for - # nnx.Variable.__init__. - if mutable is None: - actual_nnx_mutable = ( - trainable # Keras 'trainable' maps to NNX 'mutable' + from flax import nnx + + class NnxVariable(KerasVariable, nnx.Variable): + def __init__( + self, + initializer, + shape=None, + dtype=None, + trainable=True, + autocast=True, + aggregation="none", + synchronization="auto", + name=None, + layout=None, + mutable=None, + **nnx_metadata, + ): + # Determine NNX mutability. This needs to be known for + # nnx.Variable.__init__. + if mutable is None: + actual_nnx_mutable = ( + trainable # Keras 'trainable' maps to NNX 'mutable' + ) + else: + actual_nnx_mutable = mutable + + # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' + # param takes precedence. + if "mutable" in nnx_metadata and mutable is not None: + nnx_metadata["mutable"] = actual_nnx_mutable + elif "mutable" not in nnx_metadata: + nnx_metadata["mutable"] = actual_nnx_mutable + + # Initialize nnx.Variable first. + # Determine the dtype for the placeholder. + _placeholder_value = None + if shape is not None: + if dtype is not None: + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(dtype) ) else: - actual_nnx_mutable = mutable - - # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' - # param takes precedence. - if "mutable" in nnx_metadata and mutable is not None: - nnx_metadata["mutable"] = actual_nnx_mutable - elif "mutable" not in nnx_metadata: - nnx_metadata["mutable"] = actual_nnx_mutable - - # Initialize nnx.Variable first. - # Determine the dtype for the placeholder. 
- _placeholder_value = None - if shape is not None: - if dtype is not None: - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(dtype) - ) - else: - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(config.floatx()) - ) + _placeholder_value = jnp.zeros( + shape, dtype=standardize_dtype(config.floatx()) + ) + else: + if dtype is not None: + _placeholder_value = jnp.array( + 0.0, dtype=standardize_dtype(dtype) + ) else: - if dtype is not None: - _placeholder_value = jnp.array( - 0.0, dtype=standardize_dtype(dtype) - ) - else: - _placeholder_value = jnp.array( - 0.0, dtype=standardize_dtype(config.floatx()) - ) + _placeholder_value = jnp.array( + 0.0, dtype=standardize_dtype(config.floatx()) + ) - # Call nnx.Variable.__init__ directly. - nnx.Variable.__init__( - self, value=_placeholder_value, **nnx_metadata - ) + # Call nnx.Variable.__init__ directly. + nnx.Variable.__init__( + self, value=_placeholder_value, **nnx_metadata + ) - # Store JAX-specific layout using object.__setattr__ BEFORE - # KerasVariable init. - # This is because KerasVariable.__init__ will call - # self._initialize, which uses self._layout. - object.__setattr__(self, "_layout", layout) - - # Initialize KerasVariable. - super(NnxVariable, self).__init__( - initializer=initializer, - shape=shape, - dtype=dtype, - trainable=trainable, - autocast=autocast, - aggregation=aggregation, - synchronization=synchronization, - name=name, - ) + # Store JAX-specific layout using object.__setattr__ BEFORE + # KerasVariable init. + # This is because KerasVariable.__init__ will call + # self._initialize, which uses self._layout. + object.__setattr__(self, "_layout", layout) + + # Initialize KerasVariable. + super(NnxVariable, self).__init__( + initializer=initializer, + shape=shape, + dtype=dtype, + trainable=trainable, + autocast=autocast, + aggregation=aggregation, + synchronization=synchronization, + name=name, + ) - @property - def _value(self): - if hasattr(self, "raw_value"): - return self.raw_value - return None - - @_value.setter - def _value(self, new_keras_value): - self._direct_assign(new_keras_value) - - def __getstate__(self): - # Get the state from KerasVariable (attributes in __dict__) - # KerasVariable does not have a custom __getstate__, so we mimic - # default behavior. - keras_state = self.__dict__.copy() - - # Get the state from nnx.Variable - nnx_specific_state = super(KerasVariable, self).__getstate__() - - # Merge them. Keras state is primary. NNX specific state adds - # to it. - if "raw_value" in nnx_specific_state: - keras_state["_value"] = nnx_specific_state["raw_value"] - - # Add NNX attributes that are not in Keras's __dict__ - if "_trace_state" in nnx_specific_state: - keras_state["_trace_state"] = nnx_specific_state[ - "_trace_state" - ] - if "_var_metadata" in nnx_specific_state: - keras_state["_var_metadata"] = nnx_specific_state[ - "_var_metadata" - ] - - # Remove elements that might be problematic or redundant if - # nnx.Variable's __getstate__ - keras_state.pop("raw_value", None) - - return keras_state - - def __setstate__(self, state): - # Separate nnx specific keys that we added if they are not part - # of Keras __dict__ this __getstate__ puts them into the main - # state dictionary. - nnx_raw_value = state["_value"] # This was raw_value - nnx_trace_state = state.pop("_trace_state", None) - nnx_var_metadata = state.pop("_var_metadata", None) - - # Populate the instance's __dict__ with the Keras attributes. 
- self.__dict__.update(state) - - # restore the nnx.Variable specific slotted attributes. - object.__setattr__(self, "raw_value", nnx_raw_value) - - if nnx_trace_state is not None: - object.__setattr__(self, "_trace_state", nnx_trace_state) - else: - pass + @property + def _value(self): + if hasattr(self, "raw_value"): + return self.raw_value + return None + + @_value.setter + def _value(self, new_keras_value): + self._direct_assign(new_keras_value) + + def __getstate__(self): + # Get the state from KerasVariable (attributes in __dict__) + # KerasVariable does not have a custom __getstate__, so we mimic + # default behavior. + keras_state = self.__dict__.copy() + + # Get the state from nnx.Variable + nnx_specific_state = super(KerasVariable, self).__getstate__() + + # Merge them. Keras state is primary. NNX specific state adds + # to it. + if "raw_value" in nnx_specific_state: + keras_state["_value"] = nnx_specific_state["raw_value"] + + # Add NNX attributes that are not in Keras's __dict__ + if "_trace_state" in nnx_specific_state: + keras_state["_trace_state"] = nnx_specific_state["_trace_state"] + if "_var_metadata" in nnx_specific_state: + keras_state["_var_metadata"] = nnx_specific_state[ + "_var_metadata" + ] - if nnx_var_metadata is not None: - object.__setattr__(self, "_var_metadata", nnx_var_metadata) - else: - pass + # Remove elements that might be problematic or redundant if + # nnx.Variable's __getstate__ + keras_state.pop("raw_value", None) - # Ensure Keras's self._value is also consistent with the - # restored raw_value - object.__setattr__(self, "_value", nnx_raw_value) + return keras_state - if hasattr(self, "_shape") and self._shape is not None: - self._ndim = len(self._shape) - else: - # Fallback if shape isn't immediately available. - self._ndim = len(self.raw_value.shape) - - def _initialize(self, value): - # Note that variable.shape is needed by distribution_lib - self._shape = self._validate_shape(value.shape) - # We can't import the keras/distribution/distribution_lib - # due to circular dependency. - distribution = global_state.get_global_attribute("distribution") - if self._layout is None and distribution is not None: - tensor_layout = distribution.get_variable_layout(self) - from keras.src.distribution import TensorLayout - - if isinstance(tensor_layout, TensorLayout): - self._layout = tensor_layout.backend_layout - else: - self._layout = tensor_layout - self._direct_assign(value) - - def _direct_assign(self, value): - # Apply JAX-specific distribution if layout is present - if self._layout is not None: - processed_value = distribution_lib.distribute_variable( - value, self._layout - ) + def __setstate__(self, state): + # Separate nnx specific keys that we added if they are not part + # of Keras __dict__ this __getstate__ puts them into the main + # state dictionary. + nnx_raw_value = state["_value"] # This was raw_value + nnx_trace_state = state.pop("_trace_state", None) + nnx_var_metadata = state.pop("_var_metadata", None) + + # Populate the instance's __dict__ with the Keras attributes. + self.__dict__.update(state) + + # restore the nnx.Variable specific slotted attributes. 
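# A sketch of the round-trip the __getstate__/__setstate__ pair being
# rewritten here is meant to support; `v` is a hypothetical, concretely
# built variable (illustrative only):
#
#     import pickle
#     v2 = pickle.loads(pickle.dumps(v))
#     assert v2.shape == v.shape and v2.dtype == v.dtype
#
# Both halves must survive the trip: the Keras attributes in __dict__ and
# the slotted nnx.Variable state (raw_value, _trace_state, _var_metadata).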
+ object.__setattr__(self, "raw_value", nnx_raw_value) + + if nnx_trace_state is not None: + object.__setattr__(self, "_trace_state", nnx_trace_state) + else: + pass + + if nnx_var_metadata is not None: + object.__setattr__(self, "_var_metadata", nnx_var_metadata) + else: + pass + + # Ensure Keras's self._value is also consistent with the + # restored raw_value + object.__setattr__(self, "_value", nnx_raw_value) + + if hasattr(self, "_shape") and self._shape is not None: + self._ndim = len(self._shape) + else: + # Fallback if shape isn't immediately available. + self._ndim = len(self.raw_value.shape) + + def _initialize(self, value): + # Note that variable.shape is needed by distribution_lib + self._shape = self._validate_shape(value.shape) + # We can't import the keras/distribution/distribution_lib + # due to circular dependency. + distribution = global_state.get_global_attribute("distribution") + if self._layout is None and distribution is not None: + tensor_layout = distribution.get_variable_layout(self) + from keras.src.distribution import TensorLayout + + if isinstance(tensor_layout, TensorLayout): + self._layout = tensor_layout.backend_layout else: - processed_value = value + self._layout = tensor_layout + self._direct_assign(value) + + def _direct_assign(self, value): + # Apply JAX-specific distribution if layout is present + if self._layout is not None: + processed_value = distribution_lib.distribute_variable( + value, self._layout + ) + else: + processed_value = value - # Ensure that nnx.Variable part is initialized - if not hasattr(self, "_var_metadata"): - # todo: should add a warning - pass + # Ensure that nnx.Variable part is initialized + if not hasattr(self, "_var_metadata"): + # todo: should add a warning + pass - # Apply on_set_value hook if it exists + # Apply on_set_value hook if it exists + if ( + hasattr(self, "_var_metadata") + and "on_set_value" in self._var_metadata + ): + final_value = self._var_metadata["on_set_value"]( + self, processed_value + ) + else: + final_value = processed_value + + # Directly set raw_value. nnx.Variable handles mutable array + # updates + object.__setattr__(self, "raw_value", final_value) + + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype, sparse=False) + + # Overload native accessor. + def __jax_array__(self): + return self.value + + @property + def value(self): + if not hasattr(self, "raw_value"): + if not hasattr(self, "_value") or self._value is None: + if self._initializer is not None: + initial_value = self._initializer( + self._shape, dtype=self._dtype + ) + return self._maybe_autocast(initial_value) + else: + raise AttributeError( + "Variable is not properly initialized and has" + " no initializer." + ) + current_value = self._value + else: + current_value = self.raw_value if ( hasattr(self, "_var_metadata") - and "on_set_value" in self._var_metadata + and "on_get_value" in self._var_metadata ): - final_value = self._var_metadata["on_set_value"]( - self, processed_value + current_value = self._var_metadata["on_get_value"]( + self, current_value ) - else: - final_value = processed_value - - # Directly set raw_value. nnx.Variable handles mutable array - # updates - object.__setattr__(self, "raw_value", final_value) - - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype, sparse=False) - - # Overload native accessor. 
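# For context on the accessor being moved here: __jax_array__ is the
# conversion hook JAX consults when jnp receives an unknown object, so
# overloading it lets a Keras variable be passed straight into jnp ops.
# Sketch, with `v` a hypothetical initialized variable:
#
#     import jax.numpy as jnp
#     y = jnp.add(v, 1.0)  # jnp converts v via v.__jax_array__()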
- def __jax_array__(self): - return self.value - - @property - def value(self): - if not hasattr(self, "raw_value"): - if not hasattr(self, "_value") or self._value is None: - if self._initializer is not None: - initial_value = self._initializer( - self._shape, dtype=self._dtype - ) - return self._maybe_autocast(initial_value) - else: - raise AttributeError( - "Variable is not properly initialized and has" - " no initializer." - ) - current_value = self._value - else: - current_value = self.raw_value - if ( - hasattr(self, "_var_metadata") - and "on_get_value" in self._var_metadata - ): - current_value = self._var_metadata["on_get_value"]( - self, current_value - ) - if in_stateless_scope(): - scope = get_stateless_scope() - stateless_value = scope.get_current_value(self) - if stateless_value is not None: - return self._maybe_autocast(stateless_value) + if in_stateless_scope(): + scope = get_stateless_scope() + stateless_value = scope.get_current_value(self) + if stateless_value is not None: + return self._maybe_autocast(stateless_value) - return self._maybe_autocast(current_value) + return self._maybe_autocast(current_value) - def __hash__(self): - return id(self) + def __hash__(self): + return id(self) - _JAX_VARIABLE_TYPE = NnxVariable - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) + _JAX_VARIABLE_TYPE = NnxVariable def convert_to_tensor(x, dtype=None, sparse=None, ragged=None): diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index e5fcfdbebe34..3d7fb431b013 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,10 +1,4 @@ -try: - from flax import nnx # noqa F401 -except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) +from flax import nnx # noqa F401 class JaxLayer: diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 7236485a6b1f..4605318fcac8 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -55,14 +55,7 @@ from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": if is_nnx_backend_enabled(): - try: - from flax import nnx # noqa F401 - from keras.src.backend.jax.layer import NnxLayer as BackendLayer - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" - ) + from keras.src.backend.jax.layer import NnxLayer as BackendLayer else: from keras.src.backend.jax.layer import JaxLayer as BackendLayer elif backend.backend() == "torch": diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index c7c8b84cbf21..bf42fae895dd 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -1,6 +1,8 @@ import inspect import textwrap +from flax import nnx + from keras.src import backend from keras.src import dtype_policies from keras.src import tree @@ -123,15 +125,7 @@ def __new__(cls, *args, **kwargs): """ instance = super(Operation, cls).__new__(cls) if backend.backend() == "jax" and is_nnx_backend_enabled(): - try: - from flax import nnx - - vars(instance)["_object__state"] = nnx.object.ObjectState() - except ImportError: - raise ImportError( - "To use the NNX backend, you must install `flax`. " - "Please install it via `pip install flax`." - ) + vars(instance)["_object__state"] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`. 
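# Why the write above goes through vars(instance) rather than a normal
# attribute assignment, as a sketch: vars(instance) is instance.__dict__,
# so it bypasses the __setattr__ that nnx.Module installs, which would
# otherwise run before the ObjectState it depends on exists. Equivalent
# spelling:
#
#     instance.__dict__["_object__state"] = nnx.object.ObjectState()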
arg_names = inspect.getfullargspec(cls.__init__).args kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args))) diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 765263e82696..32959359dd7d 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,5 +9,5 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax>=0.10.1 +flax -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index 730f1fb2601c..e5a44501e6b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax>=0.10.1 +flax # Common deps. -r requirements-common.txt From f280dd09f78b0355e2cb00d66f5bfcf1e8657668 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:47:53 +0000 Subject: [PATCH 051/103] update jax version --- keras/src/backend/config.py | 2 +- requirements-jax-cuda.txt | 2 +- requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index f3ee55ae077c..177395b6977b 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -244,7 +244,7 @@ def is_nnx_backend_enabled(): return _NNX_ENABLED -def set_nnx_enabled(value: bool): +def set_nnx_enabled(value): global _NNX_ENABLED from keras.src.backend.common import global_state diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 32959359dd7d..765263e82696 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,5 +9,5 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax +flax>=0.10.1 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index e5a44501e6b4..730f1fb2601c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax +flax>=0.10.1 # Common deps. 
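# Sketch of the toggle whose signature PATCH 051 touches above, assuming
# it is flipped before the JAX backend module is imported -- PATCH 059
# later makes keras/src/backend/jax/__init__.py choose the Variable class
# at import time from this flag:
#
#     from keras.src.backend import config
#     config.set_nnx_enabled(True)
#     assert config.is_nnx_backend_enabled()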
-r requirements-common.txt From 75f9cc886036ff6aaa5f3ba0c64c6fc828b7bbc1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:50:01 +0000 Subject: [PATCH 052/103] ugh the jax version issue --- integration_tests/import_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index e2cd5484ca68..ec0e995a9ccf 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -11,7 +11,9 @@ "torch", "--extra-index-url https://download.pytorch.org/whl/cpu ", ), - "jax": ("jax[cpu]==0.5.0", ""), + # please update the jax version here if jax version is updated in + # requirements file + "jax": ("jax[cpu]==0.6.0", ""), "openvino": ("openvino", ""), } From 68261d465ffa537cefd2bf56684e00c107ea2ff1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 00:56:18 +0000 Subject: [PATCH 053/103] update jax version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 730f1fb2601c..ee61cb3fe564 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Jax. # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. -jax[cpu]==0.5.0 +jax[cpu]==0.6.0 flax>=0.10.1 # Common deps. -r requirements-common.txt From 1e09246c106f69a7af24076dd51d7766d92f8bb2 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 01:04:05 +0000 Subject: [PATCH 054/103] update installations --- integration_tests/import_test.py | 4 ++-- requirements-common.txt | 1 + requirements.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index ec0e995a9ccf..15107e2b4216 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -11,9 +11,9 @@ "torch", "--extra-index-url https://download.pytorch.org/whl/cpu ", ), - # please update the jax version here if jax version is updated in + # please update the jax version here if jax version is updated in # requirements file - "jax": ("jax[cpu]==0.6.0", ""), + "jax": ("jax[cpu]==0.5.0", ""), "openvino": ("openvino", ""), } diff --git a/requirements-common.txt b/requirements-common.txt index 7edc40c97a1a..21ec0efe7cdd 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,3 +24,4 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino +flax>=0.10.1 diff --git a/requirements.txt b/requirements.txt index ee61cb3fe564..730f1fb2601c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Jax. # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. -jax[cpu]==0.6.0 +jax[cpu]==0.5.0 flax>=0.10.1 # Common deps. 
-r requirements-common.txt From 8a142a1a5480305b0b153bf423e15521542a3ea1 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 01:36:52 +0000 Subject: [PATCH 055/103] update jax utils --- keras/src/utils/jax_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index 56b164d18125..2c6e757fb1f8 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -13,14 +13,15 @@ def is_in_jax_tracing_scope(x=None): def jit(*args, **kwargs): - import jax - def decorator(func): if is_nnx_backend_enabled(): from flax import nnx return nnx.jit(func, *args, **kwargs) else: - return jax.jit(func, *args, **kwargs) + if backend.backend() == "jax": + import jax + + return jax.jit(func, *args, **kwargs) return decorator From c7b2347b224477023e66f87db7798e9b287b61de Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 01:46:05 +0000 Subject: [PATCH 056/103] another requirements file fix --- requirements-common.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index 21ec0efe7cdd..7edc40c97a1a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,4 +24,3 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino -flax>=0.10.1 From 99d4307148db224c2ce39a18fa299e8f49400d4b Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 03:11:38 +0000 Subject: [PATCH 057/103] fix test --- keras/src/ops/operation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index bf42fae895dd..6f22649f46cb 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -1,8 +1,6 @@ import inspect import textwrap -from flax import nnx - from keras.src import backend from keras.src import dtype_policies from keras.src import tree @@ -125,6 +123,8 @@ def __new__(cls, *args, **kwargs): """ instance = super(Operation, cls).__new__(cls) if backend.backend() == "jax" and is_nnx_backend_enabled(): + from flax import nnx + vars(instance)["_object__state"] = nnx.object.ObjectState() # Generate a config to be returned by default by `get_config()`.
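# As of PATCH 055 above, jax_utils.jit is a decorator *factory*: it takes
# the jit options first and the function second, which is why the PATCH
# 042 trainer call sites read jit(donate_argnums=0)(self.train_step).
# Sketch of the two spellings (step_fn is hypothetical):
#
#     @jit()                                      # no options
#     def step_fn(state, batch):
#         ...
#
#     fast_step = jit(donate_argnums=0)(step_fn)  # with options
#
# PATCHes 059-062 revert the call sites to the plain jit(fn, **options)
# signature once nnx.jit/jax.jit are dispatched directly.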
arg_names = inspect.getfullargspec(cls.__init__).args From d544a0bb5830d880bf13b484b20ea2558207e74b Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Thu, 12 Jun 2025 03:23:14 +0000 Subject: [PATCH 058/103] add back flax to req common --- requirements-common.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-common.txt b/requirements-common.txt index 7edc40c97a1a..21ec0efe7cdd 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,3 +24,4 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino +flax>=0.10.1 From 3b8d90b21ec43051dbf24b9a6756f6692997dc8d Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 13 Jun 2025 23:15:59 +0000 Subject: [PATCH 059/103] address review comments --- guides/distributed_training_with_jax.py | 4 +- keras/api/_tf_keras/keras/config/__init__.py | 4 +- keras/api/config/__init__.py | 4 +- keras/src/backend/config.py | 4 +- keras/src/backend/jax/__init__.py | 4 +- keras/src/backend/jax/core.py | 89 +++++++------------- keras/src/backend/jax/core_test.py | 6 +- keras/src/backend/jax/trainer.py | 2 +- keras/src/layers/layer.py | 14 ++- keras/src/ops/operation.py | 4 +- keras/src/utils/jax_utils.py | 17 ++-- 11 files changed, 65 insertions(+), 87 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 9ce5cf42f948..adefbcfc58f5 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -53,7 +53,7 @@ from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled from keras.src.utils.jax_utils import jit from flax import nnx @@ -189,7 +189,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): # Training step, Keras provides a pure functional optimizer.stateless_apply -@jit() +@jit def train_step(train_state, x, y): ( trainable_variables, diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py index 65e32dd7f4ee..8cf3a1c30abd 100644 --- a/keras/api/_tf_keras/keras/config/__init__.py +++ b/keras/api/_tf_keras/keras/config/__init__.py @@ -17,9 +17,7 @@ from keras.src.backend.config import ( is_flash_attention_enabled as is_flash_attention_enabled, ) -from keras.src.backend.config import ( - is_nnx_backend_enabled as is_nnx_backend_enabled, -) +from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled from keras.src.backend.config import max_epochs as max_epochs from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch from keras.src.backend.config import set_epsilon as set_epsilon diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py index 65e32dd7f4ee..8cf3a1c30abd 100644 --- a/keras/api/config/__init__.py +++ b/keras/api/config/__init__.py @@ -17,9 +17,7 @@ from keras.src.backend.config import ( is_flash_attention_enabled as is_flash_attention_enabled, ) -from keras.src.backend.config import ( - is_nnx_backend_enabled as is_nnx_backend_enabled, -) +from keras.src.backend.config import is_nnx_enabled as is_nnx_enabled from keras.src.backend.config import max_epochs as max_epochs from keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch from keras.src.backend.config import set_epsilon as set_epsilon diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 
177395b6977b..b33607bc0ff7 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -233,8 +233,8 @@ def is_flash_attention_enabled(): return global_state.get_global_attribute("flash_attention", default=None) -@keras_export("keras.config.is_nnx_backend_enabled") -def is_nnx_backend_enabled(): +@keras_export("keras.config.is_nnx_enabled") +def is_nnx_enabled(): """Checks whether NNX specific features are enabled for the JAX backend. Returns: diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 859d5e3a1d8d..335eed660b46 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,4 +1,4 @@ -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image @@ -12,7 +12,7 @@ from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS -if is_nnx_backend_enabled(): +if is_nnx_enabled(): from keras.src.backend.jax.core import NnxVariable as Variable else: from keras.src.backend.jax.core import JaxVariable as Variable diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index e36d33e5b1f0..ac83037a1c28 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -62,7 +62,7 @@ def __jax_array__(self): if config.is_nnx_backend_enabled(): from flax import nnx - class NnxVariable(KerasVariable, nnx.Variable): + class NnxVariable(JaxVariable, nnx.Variable): def __init__( self, initializer, @@ -126,8 +126,9 @@ def __init__( # self._initialize, which uses self._layout. object.__setattr__(self, "_layout", layout) - # Initialize KerasVariable. - super(NnxVariable, self).__init__( + # Initialize JaxVariable (which will call KerasVariable.__init__). + JaxVariable.__init__( + self, initializer=initializer, shape=shape, dtype=dtype, @@ -152,10 +153,13 @@ def __getstate__(self): # Get the state from KerasVariable (attributes in __dict__) # KerasVariable does not have a custom __getstate__, so we mimic # default behavior. - keras_state = self.__dict__.copy() + try: + keras_state = KerasVariable.__getstate__(self) + except AttributeError: + keras_state = object.__getstate__(self) # Get the state from nnx.Variable - nnx_specific_state = super(KerasVariable, self).__getstate__() + nnx_specific_state = nnx.Variable.__getstate__(self) # Merge them. Keras state is primary. NNX specific state adds # to it. @@ -202,7 +206,7 @@ def __setstate__(self, state): # Ensure Keras's self._value is also consistent with the # restored raw_value - object.__setattr__(self, "_value", nnx_raw_value) + self._value = nnx_raw_value if hasattr(self, "_shape") and self._shape is not None: self._ndim = len(self._shape) @@ -210,30 +214,12 @@ def __setstate__(self, state): # Fallback if shape isn't immediately available. self._ndim = len(self.raw_value.shape) - def _initialize(self, value): - # Note that variable.shape is needed by distribution_lib - self._shape = self._validate_shape(value.shape) - # We can't import the keras/distribution/distribution_lib - # due to circular dependency. 
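# Context for the deletions below: PATCH 059 rebases the class as
# NnxVariable(JaxVariable, nnx.Variable), so _initialize and the layout
# plumbing now come from the Keras JAX variable instead of being
# duplicated here. Approximate method resolution order, as a sketch:
#
#     NnxVariable -> JaxVariable -> KerasVariable -> nnx.Variable -> ...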
- distribution = global_state.get_global_attribute("distribution") - if self._layout is None and distribution is not None: - tensor_layout = distribution.get_variable_layout(self) - from keras.src.distribution import TensorLayout - - if isinstance(tensor_layout, TensorLayout): - self._layout = tensor_layout.backend_layout - else: - self._layout = tensor_layout - self._direct_assign(value) - def _direct_assign(self, value): # Apply JAX-specific distribution if layout is present if self._layout is not None: - processed_value = distribution_lib.distribute_variable( + value = distribution_lib.distribute_variable( value, self._layout ) - else: - processed_value = value # Ensure that nnx.Variable part is initialized if not hasattr(self, "_var_metadata"): @@ -245,48 +231,35 @@ def _direct_assign(self, value): hasattr(self, "_var_metadata") and "on_set_value" in self._var_metadata ): - final_value = self._var_metadata["on_set_value"]( - self, processed_value - ) - else: - final_value = processed_value + value = self._var_metadata["on_set_value"](self, value) # Directly set raw_value. nnx.Variable handles mutable array # updates - object.__setattr__(self, "raw_value", final_value) - - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype, sparse=False) - - # Overload native accessor. - def __jax_array__(self): - return self.value + object.__setattr__(self, "raw_value", value) @property def value(self): if not hasattr(self, "raw_value"): - if not hasattr(self, "_value") or self._value is None: - if self._initializer is not None: - initial_value = self._initializer( - self._shape, dtype=self._dtype - ) - return self._maybe_autocast(initial_value) - else: - raise AttributeError( - "Variable is not properly initialized and has" - " no initializer." - ) - current_value = self._value - else: - current_value = self.raw_value - if ( - hasattr(self, "_var_metadata") - and "on_get_value" in self._var_metadata - ): - current_value = self._var_metadata["on_get_value"]( - self, current_value + if self._initializer is not None: + self._initialize( + self._initializer(self.shape, dtype=self.dtype) + ) + else: + raise AttributeError( + "Variable is not properly initialized (raw_value " + "missing) and has no initializer." 
) + current_value = self.raw_value + + if ( + hasattr(self, "_var_metadata") + and "on_get_value" in self._var_metadata + ): + current_value = self._var_metadata["on_get_value"]( + self, current_value + ) + if in_stateless_scope(): scope = get_stateless_scope() stateless_value = scope.get_current_value(self) diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py index 77c1727c788e..0578c97f4964 100644 --- a/keras/src/backend/jax/core_test.py +++ b/keras/src/backend/jax/core_test.py @@ -8,9 +8,9 @@ import keras from keras.src import backend from keras.src import testing -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled -if is_nnx_backend_enabled(): +if is_nnx_enabled(): from flax import nnx from keras.src.backend.jax.core import NnxVariable @@ -21,7 +21,7 @@ reason="JAX backend specific test for core Variable integration with NNX.", ) @pytest.mark.skipif( - not is_nnx_backend_enabled(), + not is_nnx_enabled(), reason="Test requires NNX backend to be enabled by default for setup.", ) class JaxCoreVariableTest(testing.TestCase): diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 975564fd0909..29f8b4d6ac7e 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -234,7 +234,7 @@ def concatenate(outputs): return output if not self.run_eagerly and self.jit_compile: - concatenate = jit()(concatenate) + concatenate = jit(concatenate) def iterator_step(state, iterator): data = next(iterator) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 4605318fcac8..22cad1a6a03b 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1664,8 +1664,20 @@ def get_config(self): return {**base_config, **config} def _open_name_scope(self): + from keras.src.utils import jax_utils # avoid circular imports + if self._parent_path is None: - self._parent_path = current_path() + # Avoid mutating _parent_path during a JAX trace if it's part of + # nnx.Object state and the object was created at a different trace + # level. We check if we are in NNX mode and if we are in a JAX + # trace. + if not ( + is_nnx_backend_enabled() and jax_utils.is_in_jax_tracing_scope() + ): + try: + self._parent_path = current_path() + except Exception: + pass return backend.name_scope(self.name, caller=self) def rematerialized_call(self, layer_call, *args, **kwargs): diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index 6f22649f46cb..aaf052a11588 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -6,7 +6,7 @@ from keras.src import tree from keras.src.api_export import keras_export from keras.src.backend.common.keras_tensor import any_symbolic_tensors -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled from keras.src.ops.node import Node from keras.src.utils import python_utils from keras.src.utils import traceback_utils @@ -122,7 +122,7 @@ def __new__(cls, *args, **kwargs): to manually implement `get_config()`. 
""" instance = super(Operation, cls).__new__(cls) - if backend.backend() == "jax" and is_nnx_backend_enabled(): + if backend.backend() == "jax" and is_nnx_enabled(): from flax import nnx vars(instance)["_object__state"] = nnx.object.ObjectState() diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index 2c6e757fb1f8..d0fb29a45d02 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -13,15 +13,12 @@ def is_in_jax_tracing_scope(x=None): def jit(*args, **kwargs): - def decorator(func): - if is_nnx_backend_enabled(): - from flax import nnx + if is_nnx_backend_enabled(): + from flax import nnx - return nnx.jit(func, *args, **kwargs) - else: - if backend.backend() == "jax": - import jax + return nnx.jit + else: + if backend.backend() == "jax": + import jax - return jax.jit(func, *args, **kwargs) - - return decorator + return jax.jit From 0e0fcd1888bce69f9fb7f3dc34f18a7591ba1a15 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 13 Jun 2025 23:25:12 +0000 Subject: [PATCH 060/103] fix tests --- keras/src/backend/jax/core.py | 4 +++- keras/src/layers/layer.py | 10 ++++------ keras/src/utils/jax_utils.py | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index ac83037a1c28..1886b1ae0564 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -59,7 +59,7 @@ def __jax_array__(self): _JAX_VARIABLE_TYPE = JaxVariable -if config.is_nnx_backend_enabled(): +if config.is_nnx_enabled(): from flax import nnx class NnxVariable(JaxVariable, nnx.Variable): @@ -268,6 +268,8 @@ def value(self): return self._maybe_autocast(current_value) + # Todo: NNX has agreed to fix it on thier end. I will remove it once + # that is done def __hash__(self): return id(self) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 22cad1a6a03b..4ca56dd24262 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -38,7 +38,7 @@ from keras.src.backend.common.name_scope import current_path from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec @@ -54,7 +54,7 @@ if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer elif backend.backend() == "jax": - if is_nnx_backend_enabled(): + if is_nnx_enabled(): from keras.src.backend.jax.layer import NnxLayer as BackendLayer else: from keras.src.backend.jax.layer import JaxLayer as BackendLayer @@ -1543,7 +1543,7 @@ def __setattr__(self, name, value): # NNX-specific bypass for `_called` and `built` attributes if ( backend.backend() == "jax" - and is_nnx_backend_enabled() + and is_nnx_enabled() and (name == "_called" or name == "built") ): object.__setattr__(self, name, value) @@ -1671,9 +1671,7 @@ def _open_name_scope(self): # nnx.Object state and the object was created at a different trace # level. We check if we are in NNX mode and if we are in a JAX # trace. 
- if not ( - is_nnx_backend_enabled() and jax_utils.is_in_jax_tracing_scope() - ): + if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()): try: self._parent_path = current_path() except Exception: diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index d0fb29a45d02..3cdf24c5a19c 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -1,5 +1,5 @@ from keras.src import backend -from keras.src.backend.config import is_nnx_backend_enabled +from keras.src.backend.config import is_nnx_enabled def is_in_jax_tracing_scope(x=None): @@ -13,7 +13,7 @@ def is_in_jax_tracing_scope(x=None): def jit(*args, **kwargs): - if is_nnx_backend_enabled(): + if is_nnx_enabled(): from flax import nnx return nnx.jit From bd66ec8d037788017dd2474b264cd9acb72d1eca Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 13 Jun 2025 23:50:30 +0000 Subject: [PATCH 061/103] fix tests address more comments --- keras/src/backend/jax/core.py | 52 ++++++-------------------------- keras/src/backend/jax/layer.py | 3 +- keras/src/backend/jax/trainer.py | 4 +-- keras/src/layers/layer.py | 3 -- keras/src/ops/operation.py | 3 -- 5 files changed, 13 insertions(+), 52 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 1886b1ae0564..1bf75ca45374 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -77,43 +77,15 @@ def __init__( mutable=None, **nnx_metadata, ): - # Determine NNX mutability. This needs to be known for - # nnx.Variable.__init__. - if mutable is None: - actual_nnx_mutable = ( - trainable # Keras 'trainable' maps to NNX 'mutable' - ) - else: - actual_nnx_mutable = mutable - # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable' # param takes precedence. - if "mutable" in nnx_metadata and mutable is not None: - nnx_metadata["mutable"] = actual_nnx_mutable - elif "mutable" not in nnx_metadata: - nnx_metadata["mutable"] = actual_nnx_mutable + nnx_metadata["mutable"] = trainable if mutable is None else mutable # Initialize nnx.Variable first. # Determine the dtype for the placeholder. - _placeholder_value = None - if shape is not None: - if dtype is not None: - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(dtype) - ) - else: - _placeholder_value = jnp.zeros( - shape, dtype=standardize_dtype(config.floatx()) - ) - else: - if dtype is not None: - _placeholder_value = jnp.array( - 0.0, dtype=standardize_dtype(dtype) - ) - else: - _placeholder_value = jnp.array( - 0.0, dtype=standardize_dtype(config.floatx()) - ) + _placeholder_value = jnp.zeros( + shape or (), dtype=standardize_dtype(dtype) + ) # Call nnx.Variable.__init__ directly. nnx.Variable.__init__( @@ -239,6 +211,11 @@ def _direct_assign(self, value): @property def value(self): + if in_stateless_scope(): + scope = get_stateless_scope() + stateless_value = scope.get_current_value(self) + if stateless_value is not None: + return self._maybe_autocast(stateless_value) if not hasattr(self, "raw_value"): if self._initializer is not None: self._initialize( @@ -249,9 +226,7 @@ def value(self): "Variable is not properly initialized (raw_value " "missing) and has no initializer." 
) - current_value = self.raw_value - if ( hasattr(self, "_var_metadata") and "on_get_value" in self._var_metadata @@ -259,16 +234,9 @@ def value(self): current_value = self._var_metadata["on_get_value"]( self, current_value ) - - if in_stateless_scope(): - scope = get_stateless_scope() - stateless_value = scope.get_current_value(self) - if stateless_value is not None: - return self._maybe_autocast(stateless_value) - return self._maybe_autocast(current_value) - # Todo: NNX has agreed to fix it on thier end. I will remove it once + # Todo: NNX has agreed to fix it on their end. I will remove it once # that is done def __hash__(self): return id(self) diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index 3d7fb431b013..c14f959b146b 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -6,5 +6,4 @@ class JaxLayer: class NnxLayer(nnx.Module): - def __init_subclass__(cls): - super().__init_subclass__() + pass diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 29f8b4d6ac7e..a093f0c77592 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -278,7 +278,7 @@ def make_train_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - train_step = jit(donate_argnums=0)(self.train_step) + train_step = jit(self.train_step, donate_argnums=0) else: train_step = self.train_step @@ -294,7 +294,7 @@ def make_test_function(self, force=False): # so that jax will reuse the memory buffer for outputs. # This will reduce the memory usage of the training function by # half. - test_step = jit(donate_argnums=0)(self.test_step) + test_step = jit(self.test_step, donate_argnums=0) else: test_step = self.test_step diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 4ca56dd24262..280a99506acf 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -222,9 +222,6 @@ def call(self, inputs): ``` """ - def __init_subclass__(cls): - super().__init_subclass__() - def __new__(cls, *args, **kwargs): obj = super().__new__(cls, *args, **kwargs) # Wrap the user-provided `build` method in the `build_wrapper` diff --git a/keras/src/ops/operation.py b/keras/src/ops/operation.py index aaf052a11588..3b934761c866 100644 --- a/keras/src/ops/operation.py +++ b/keras/src/ops/operation.py @@ -15,9 +15,6 @@ @keras_export("keras.Operation") class Operation: - def __init_subclass__(cls): - super().__init_subclass__() - def __init__(self, name=None): if name is None: name = auto_name(self.__class__.__name__) From 8637c18c44e3fed3b0af9c946f49ff6e09b31079 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 13 Jun 2025 23:57:12 +0000 Subject: [PATCH 062/103] fix tests --- keras/src/backend/jax/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index a093f0c77592..8e60ebfa167a 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -312,7 +312,7 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = jit(donate_argnums=0)(predict_step) + predict_step = jit(predict_step, donate_argnums=0) _step_function = self._make_function( predict_step, concatenate_outputs=True From 97d7371ccc0b659a3d28ff4ce1476a3900797af9 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: 
Mon, 16 Jun 2025 21:29:22 +0000 Subject: [PATCH 063/103] fix tests --- guides/distributed_training_with_jax.py | 2 -- keras/src/backend/jax/layer.py | 10 ++++++---- requirements-common.txt | 1 - 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index adefbcfc58f5..936b73c40950 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -53,9 +53,7 @@ from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P -from keras.src.backend.config import is_nnx_enabled from keras.src.utils.jax_utils import jit -from flax import nnx def get_model(): diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py index c14f959b146b..7784bae431ed 100644 --- a/keras/src/backend/jax/layer.py +++ b/keras/src/backend/jax/layer.py @@ -1,9 +1,11 @@ -from flax import nnx # noqa F401 +from keras.src.backend.config import is_nnx_enabled +if is_nnx_enabled(): + from flax import nnx -class JaxLayer: - pass + class NnxLayer(nnx.Module): + pass -class NnxLayer(nnx.Module): +class JaxLayer: pass diff --git a/requirements-common.txt b/requirements-common.txt index 21ec0efe7cdd..7edc40c97a1a 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -24,4 +24,3 @@ coverage!=7.6.5 # 7.6.5 breaks CI # for onnx_test.py onnxruntime openvino -flax>=0.10.1 From b20321e178036f027be31bab56720a7cc9da186f Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 16 Jun 2025 21:46:27 +0000 Subject: [PATCH 064/103] fix jax tests --- keras/src/utils/jax_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index 3cdf24c5a19c..a6933515c12c 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -12,13 +12,15 @@ def is_in_jax_tracing_scope(x=None): return False -def jit(*args, **kwargs): +def jit(func=None, *args, **kwargs): + jit_compiler = None if is_nnx_enabled(): from flax import nnx - return nnx.jit + jit_compiler = nnx.jit else: if backend.backend() == "jax": import jax - return jax.jit + jit_compiler = jax.jit + return jit_compiler(func, *args, **kwargs) From 46818b276dcc9a2c1692541c9779992a292792d3 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 16 Jun 2025 21:49:58 +0000 Subject: [PATCH 065/103] revert guide --- guides/distributed_training_with_jax.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 936b73c40950..25b9e4f64727 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -53,7 +53,6 @@ from jax.sharding import Mesh from jax.sharding import NamedSharding from jax.sharding import PartitionSpec as P -from keras.src.utils.jax_utils import jit def get_model(): @@ -187,7 +186,7 @@ def compute_loss(trainable_variables, non_trainable_variables, x, y): # Training step, Keras provides a pure functional optimizer.stateless_apply -@jit +@jax.jit def train_step(train_state, x, y): ( trainable_variables, @@ -271,4 +270,4 @@ def get_replicated_train_state(devices): """ That's it! 
-""" +""" \ No newline at end of file From 9064df0a402ac4b9924113c4ad01e17d9b0f3d60 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 16 Jun 2025 21:59:39 +0000 Subject: [PATCH 066/103] fix code format --- guides/distributed_training_with_jax.py | 2 +- integration_tests/import_test.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/guides/distributed_training_with_jax.py b/guides/distributed_training_with_jax.py index 25b9e4f64727..6f6dbbf25d78 100644 --- a/guides/distributed_training_with_jax.py +++ b/guides/distributed_training_with_jax.py @@ -270,4 +270,4 @@ def get_replicated_train_state(devices): """ That's it! -""" \ No newline at end of file +""" diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 15107e2b4216..27be33ae30df 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -15,6 +15,7 @@ # requirements file "jax": ("jax[cpu]==0.5.0", ""), "openvino": ("openvino", ""), + "nnx": ("flax>=0.10.1", ""), } From c127c2b249908e22ec64575f42d5cf3d35a92214 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 16 Jun 2025 22:32:00 +0000 Subject: [PATCH 067/103] fix tests and jit --- integration_tests/import_test.py | 3 +-- keras/src/backend/jax/trainer.py | 9 ++++++++- keras/src/utils/jax_utils.py | 15 --------------- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 27be33ae30df..7715189bddc7 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -13,9 +13,8 @@ ), # please update the jax version here if jax version is updated in # requirements file - "jax": ("jax[cpu]==0.5.0", ""), + "jax": ("jax[cpu]==0.5.0flax>=0.10.0", ""), "openvino": ("openvino", ""), - "nnx": ("flax>=0.10.1", ""), } diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 8e60ebfa167a..327b7968f953 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -12,13 +12,20 @@ from keras.src import tree from keras.src.backend import config from keras.src.backend import distribution_lib as jax_distribution_lib +from keras.src.backend.config import is_nnx_enabled from keras.src.distribution import distribution_lib from keras.src.trainers import trainer as base_trainer from keras.src.trainers.data_adapters import array_slicing from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.epoch_iterator import EpochIterator from keras.src.utils import traceback_utils -from keras.src.utils.jax_utils import jit + +if is_nnx_enabled(): + from flax import nnx + + jit = nnx.jit +else: + jit = jax.jit class JAXTrainer(base_trainer.Trainer): diff --git a/keras/src/utils/jax_utils.py b/keras/src/utils/jax_utils.py index a6933515c12c..d5375785f762 100644 --- a/keras/src/utils/jax_utils.py +++ b/keras/src/utils/jax_utils.py @@ -1,5 +1,4 @@ from keras.src import backend -from keras.src.backend.config import is_nnx_enabled def is_in_jax_tracing_scope(x=None): @@ -10,17 +9,3 @@ def is_in_jax_tracing_scope(x=None): if c.__name__ == "Tracer" and c.__module__.startswith("jax"): return True return False - - -def jit(func=None, *args, **kwargs): - jit_compiler = None - if is_nnx_enabled(): - from flax import nnx - - jit_compiler = nnx.jit - else: - if backend.backend() == "jax": - import jax - - jit_compiler = jax.jit - return jit_compiler(func, *args, **kwargs) From b81030b0d70f51a8defb2812b0ee34f0343566c2 Mon Sep 17 00:00:00 2001 
From: divyashreepathihalli Date: Mon, 16 Jun 2025 22:38:44 +0000 Subject: [PATCH 068/103] fix import test --- integration_tests/import_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 7715189bddc7..45d933a1e12d 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -13,7 +13,7 @@ ), # please update the jax version here if jax version is updated in # requirements file - "jax": ("jax[cpu]==0.5.0flax>=0.10.0", ""), + "jax": ("jax[cpu]==0.5.0 flax>=0.10.0", ""), "openvino": ("openvino", ""), } From f0b10ef6155b32060abd5f12af788155d552d939 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 18 Jun 2025 21:27:13 +0000 Subject: [PATCH 069/103] try to fix memory error --- keras/src/backend/jax/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 1bf75ca45374..12934aa60b53 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -83,7 +83,7 @@ def __init__( # Initialize nnx.Variable first. # Determine the dtype for the placeholder. - _placeholder_value = jnp.zeros( + _placeholder_value = jax.ShapeDtypeStruct( shape or (), dtype=standardize_dtype(dtype) ) From d18dd33b936aa4c1c78692f7aff980be8d17c35a Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 20 Jun 2025 02:53:41 +0000 Subject: [PATCH 070/103] revert memory fix --- keras/src/backend/jax/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 12934aa60b53..1bf75ca45374 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -83,7 +83,7 @@ def __init__( # Initialize nnx.Variable first. # Determine the dtype for the placeholder. - _placeholder_value = jax.ShapeDtypeStruct( + _placeholder_value = jnp.zeros( shape or (), dtype=standardize_dtype(dtype) ) From 51aa455ddf451e433ed2700796c8846fea7383c6 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 20 Jun 2025 03:10:17 +0000 Subject: [PATCH 071/103] fix test --- keras/src/backend/common/variables.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index d40b67e06174..21f52be89dca 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -211,6 +211,12 @@ def __init__( def _deferred_initialize(self): if self._value is not None: + # If NNX is enabled, it's possible the variable was already + # initialized by a concrete call. In this case, + # _deferred_initialize becomes a no-op for this variable. 
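On PATCHes 069/070 above: `jax.ShapeDtypeStruct` describes an array's shape and dtype without allocating device memory, which is why it was tried as the placeholder; the quick revert suggests the nnx.Variable backing store needs a concrete array after all. A small comparison sketch:

    import jax
    import jax.numpy as jnp

    spec = jax.ShapeDtypeStruct((1024, 1024), jnp.float32)
    print(spec.shape, spec.dtype)  # (1024, 1024) float32 -- metadata only

    buf = jnp.zeros((1024, 1024), jnp.float32)
    print(buf.nbytes)  # 4194304 -- a real 4 MiB buffer is allocated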
+ if config.is_nnx_enabled(): + self._initializer = None # Clear initializer as it's now "used" + return raise ValueError(f"Variable {self.path} is already initialized.") if in_stateless_scope(): From a02f410924876afa9ca901d491a9e77c15094627 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 20 Jun 2025 03:38:09 +0000 Subject: [PATCH 072/103] fix test --- integration_tests/basic_full_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py index ae5c7a4c0449..66440ed6ab37 100644 --- a/integration_tests/basic_full_flow.py +++ b/integration_tests/basic_full_flow.py @@ -50,5 +50,5 @@ def test_basic_fit(self): def test_basic_fit_no_training(self): model = MyModel(hidden_dim=2, output_dim=1) x = np.random.random((128, 4)) - model.predict(x) model(x) + model.predict(x) From adca8da9d11fb485d7b0ec4969bcfd168988756e Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 20 Jun 2025 03:40:39 +0000 Subject: [PATCH 073/103] fix test --- integration_tests/import_test.py | 2 +- requirements-jax-cuda.txt | 2 +- requirements.txt | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 45d933a1e12d..951dcc27edf4 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -13,7 +13,7 @@ ), # please update the jax version here if jax version is updated in # requirements file - "jax": ("jax[cpu]==0.5.0 flax>=0.10.0", ""), + "jax": ("jax[cpu]==0.5.2 flax>=0.10.6", ""), "openvino": ("openvino", ""), } diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index 765263e82696..bc91ec59bbd6 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,5 +9,5 @@ torch==2.6.0 # Jax with cuda support. --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax>=0.10.1 +flax>=0.10.6 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index 730f1fb2601c..70f8ccefaa2d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Jax. # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. -jax[cpu]==0.5.0 -flax>=0.10.1 +jax[cpu]==0.5.2 +flax>=0.10.6 # Common deps. -r requirements-common.txt From d8ca7524a2efb3292e469202f14c3115fb309586 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 20 Jun 2025 16:18:40 +0000 Subject: [PATCH 074/103] revert version back --- integration_tests/import_test.py | 2 +- requirements-jax-cuda.txt | 2 +- requirements.txt | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index 951dcc27edf4..c374e386483d 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -13,7 +13,7 @@ ), # please update the jax version here if jax version is updated in # requirements file - "jax": ("jax[cpu]==0.5.2 flax>=0.10.6", ""), + "jax": ("jax[cpu]==0.5.1 flax>=0.10.1", ""), "openvino": ("openvino", ""), } diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt index bc91ec59bbd6..765263e82696 100644 --- a/requirements-jax-cuda.txt +++ b/requirements-jax-cuda.txt @@ -9,5 +9,5 @@ torch==2.6.0 # Jax with cuda support. 
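The PATCH 071 guard above, reduced to its essence: when NNX is enabled, a variable may already hold a concrete value by the time deferred initialization runs, so the check degrades to a no-op instead of raising. `LazyVariable` below is a stand-in, not Keras' real `Variable`:

    class LazyVariable:
        def __init__(self, initializer):
            self._initializer = initializer
            self._value = None

        def _deferred_initialize(self):
            if self._value is not None:
                # Already realized by a concrete call; treat the
                # initializer as consumed and return quietly.
                self._initializer = None
                return
            self._value = self._initializer()

PATCH 072's predict/call reordering in the test works around the same build issue; later patches flip that ordering back once PATCH 095 fixes the trainer itself.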
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax[cuda12]==0.6.0 -flax>=0.10.6 +flax>=0.10.1 -r requirements-common.txt diff --git a/requirements.txt b/requirements.txt index 70f8ccefaa2d..730f1fb2601c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Jax. # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. -jax[cpu]==0.5.2 -flax>=0.10.6 +jax[cpu]==0.5.0 +flax>=0.10.1 # Common deps. -r requirements-common.txt From a5741bea4cc7f35d84f30fd2beec612d88b6f84b Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 8 Jul 2025 19:58:35 -0700 Subject: [PATCH 075/103] Update functional.py --- keras/src/models/functional.py | 94 +++++++++++++++++++++++++++++++--- 1 file changed, 88 insertions(+), 6 deletions(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 72c781e83d81..72072cf3089f 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -7,6 +7,7 @@ from keras.src import ops from keras.src import tree from keras.src.backend.common import global_state +from keras.src.backend.config import is_nnx_enabled from keras.src.layers.core.input_layer import Input from keras.src.layers.core.input_layer import InputLayer from keras.src.layers.input_spec import InputSpec @@ -139,6 +140,11 @@ def __init__(self, inputs, outputs, name=None, **kwargs): self.trainable = trainable self._layers = self.layers + + # Special handling for NNX to ensure consistent layer instance usage + if is_nnx_enabled(): + self._setup_nnx_layer_mapping() + self.build(None) # We will convert directly (to the correct dtype per input). self._convert_input_args = False @@ -146,6 +152,27 @@ def __init__(self, inputs, outputs, name=None, **kwargs): output_layers = [x._keras_history[0] for x in self.outputs] self.output_names = [x.name for x in output_layers] + def _setup_nnx_layer_mapping(self): + """Setup layer mapping for NNX to ensure consistent layer instances.""" + # Create a mapping from operation id to layer instance + self._nnx_layer_mapping = {} + + # Store layers as direct attributes for NNX traversal + for i, layer in enumerate(self._layers): + if isinstance(layer, Layer): + # Store layer as direct attribute with unique name + attr_name = f"_layer_{i}_{layer.name}" + setattr(self, attr_name, layer) + # Map the operation id to this layer instance + self._nnx_layer_mapping[id(layer)] = layer + + # Also map any operations in the graph to ensure consistency + for operation in self._operations: + if isinstance(operation, Layer): + # Ensure the graph operation points to the same instance + if id(operation) not in self._nnx_layer_mapping: + self._nnx_layer_mapping[id(operation)] = operation + def _lock_state(self): # Unlike other layers, we allow Functional state to be mutable after # build. E.g. 
to attach a layer to a model that is not part of the @@ -180,14 +207,69 @@ def call(self, inputs, training=None, mask=None, **kwargs): for x, mask in zip(inputs, masks): if mask is not None: backend.set_keras_mask(x, mask) - outputs = self._run_through_graph( - inputs, - operation_fn=lambda op: operation_fn( - op, training=training, **kwargs - ), - ) + + # Use NNX-compatible execution when NNX is enabled + if is_nnx_enabled(): + outputs = self._run_through_graph_nnx_compatible( + inputs, + operation_fn=lambda op: operation_fn( + op, training=training, **kwargs + ), + ) + else: + outputs = self._run_through_graph( + inputs, + operation_fn=lambda op: operation_fn( + op, training=training, **kwargs + ), + ) return unpack_singleton(outputs) + def _run_through_graph_nnx_compatible(self, inputs, operation_fn, call_fn=None): + """NNX-compatible graph execution that ensures consistent layer instances.""" + inputs = tree.flatten(inputs) + + # Dictionary mapping reference tensors to computed tensors. + tensor_dict = {} + for x, y in zip(self.inputs, inputs): + tensor_dict[id(x)] = y + + nodes_by_depth = self._nodes_by_depth + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + for depth in depth_keys: + nodes = nodes_by_depth[depth] + for node in nodes: + if not node.operation or node.is_input: + continue # Input tensors already exist. + + if any(id(x) not in tensor_dict for x in node.input_tensors): + continue # Node is not computable, try skipping. + + args, kwargs = node.arguments.fill_in(tensor_dict) + + # Use the consistent layer instance for NNX compatibility + operation = node.operation + if hasattr(self, '_nnx_layer_mapping') and id(operation) in self._nnx_layer_mapping: + operation = self._nnx_layer_mapping[id(operation)] + + op = operation_fn(operation) + if call_fn is not None: + outputs = call_fn(op, *args, **kwargs) + else: + outputs = op(*args, **kwargs) + + # Update tensor_dict. + for x, y in zip(node.outputs, tree.flatten(outputs)): + tensor_dict[id(x)] = y + + output_tensors = [] + for x in self.outputs: + output_tensors.append(tensor_dict[id(x)]) + + return tree.pack_sequence_as(self._outputs_struct, output_tensors) + def compute_output_spec(self, inputs, training=None, mask=None): # From Function return super().compute_output_spec(inputs) From 02d607bbd021f16e9e032b19e1c24aa319021dac Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Tue, 8 Jul 2025 20:01:05 -0700 Subject: [PATCH 076/103] Update core.py --- keras/src/backend/jax/core.py | 45 ++++++++++++++++------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 1bf75ca45374..9d0db225557e 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -81,24 +81,21 @@ def __init__( # param takes precedence. nnx_metadata["mutable"] = trainable if mutable is None else mutable - # Initialize nnx.Variable first. - # Determine the dtype for the placeholder. - _placeholder_value = jnp.zeros( - shape or (), dtype=standardize_dtype(dtype) - ) - - # Call nnx.Variable.__init__ directly. - nnx.Variable.__init__( - self, value=_placeholder_value, **nnx_metadata - ) - - # Store JAX-specific layout using object.__setattr__ BEFORE - # KerasVariable init. - # This is because KerasVariable.__init__ will call - # self._initialize, which uses self._layout. - object.__setattr__(self, "_layout", layout) - - # Initialize JaxVariable (which will call KerasVariable.__init__). 
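The graph walk in PATCH 075's `_run_through_graph_nnx_compatible` above, in miniature: tensors are keyed by `id()` and nodes are visited from deepest to shallowest. A toy version with plain callables standing in for Keras operations; `run_graph` is illustrative only:

    def run_graph(input_ids, inputs, nodes, output_ids):
        # nodes: (op, input_ids, output_ids) triples, already depth-sorted.
        tensor_dict = dict(zip(input_ids, inputs))
        for op, in_ids, out_ids in nodes:
            if any(i not in tensor_dict for i in in_ids):
                continue  # not computable yet; skip, like the real code
            outs = op(*(tensor_dict[i] for i in in_ids))
            outs = outs if isinstance(outs, tuple) else (outs,)
            tensor_dict.update(zip(out_ids, outs))
        return [tensor_dict[o] for o in output_ids]

Keying by `id()` is sound only because the model keeps every KerasTensor alive for its lifetime; a garbage-collected tensor's id could be reused.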
+ # First, initialize a basic nnx.Variable with a dummy value + # This sets up the NNX variable structure + if shape is None: + dummy_value = jnp.array(0.0) + else: + dummy_value = jnp.zeros(shape, dtype=standardize_dtype(dtype)) + + # Initialize nnx.Variable first + nnx.Variable.__init__(self, value=dummy_value, **nnx_metadata) + + # Now we can safely set layout + self._layout = layout + + # Initialize JaxVariable (which will call KerasVariable.__init__ + # and set up the real value). JaxVariable.__init__( self, initializer=initializer, @@ -111,6 +108,9 @@ def __init__( name=name, ) + # The real value is now set in self._value, sync it to raw_value + object.__setattr__(self, "raw_value", self._value) + @property def _value(self): if hasattr(self, "raw_value"): @@ -193,11 +193,6 @@ def _direct_assign(self, value): value, self._layout ) - # Ensure that nnx.Variable part is initialized - if not hasattr(self, "_var_metadata"): - # todo: should add a warning - pass - # Apply on_set_value hook if it exists if ( hasattr(self, "_var_metadata") @@ -205,8 +200,8 @@ def _direct_assign(self, value): ): value = self._var_metadata["on_set_value"](self, value) - # Directly set raw_value. nnx.Variable handles mutable array - # updates + # Set the value for both Keras and NNX parts + # This ensures both systems see the same value object.__setattr__(self, "raw_value", value) @property From 68d0b6881e391709fa13a67931ca9e1ee059c334 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 9 Jul 2025 18:43:06 +0000 Subject: [PATCH 077/103] remove nnx workflow --- .github/workflows/actions.yml | 2 +- .github/workflows/config/nnx/keras.json | 7 ------- keras/src/backend/jax/core.py | 4 ++-- keras/src/models/functional.py | 23 ++++++++++++++--------- shell/lint.sh | 0 5 files changed, 17 insertions(+), 19 deletions(-) delete mode 100644 .github/workflows/config/nnx/keras.json create mode 100644 shell/lint.sh diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 16bdb759a9c0..b9e785dfc949 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -16,7 +16,7 @@ jobs: fail-fast: false matrix: python-version: ['3.10'] - backend: [tensorflow, jax, torch, numpy, openvino, nnx] + backend: [tensorflow, jax, torch, numpy, openvino] name: Run tests runs-on: ubuntu-latest env: diff --git a/.github/workflows/config/nnx/keras.json b/.github/workflows/config/nnx/keras.json deleted file mode 100644 index d6bb3e7fd4d5..000000000000 --- a/.github/workflows/config/nnx/keras.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "floatx": "float32", - "epsilon": 1e-07, - "backend": "jax", - "image_data_format": "channels_last", - "nnx_enabled": true -} diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index 9d0db225557e..ccef9e986699 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -87,10 +87,10 @@ def __init__( dummy_value = jnp.array(0.0) else: dummy_value = jnp.zeros(shape, dtype=standardize_dtype(dtype)) - + # Initialize nnx.Variable first nnx.Variable.__init__(self, value=dummy_value, **nnx_metadata) - + # Now we can safely set layout self._layout = layout diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 72072cf3089f..e53c3ce0de0f 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -140,11 +140,11 @@ def __init__(self, inputs, outputs, name=None, **kwargs): self.trainable = trainable self._layers = self.layers - + # Special handling for NNX to ensure consistent 
layer instance usage if is_nnx_enabled(): self._setup_nnx_layer_mapping() - + self.build(None) # We will convert directly (to the correct dtype per input). self._convert_input_args = False @@ -156,7 +156,7 @@ def _setup_nnx_layer_mapping(self): """Setup layer mapping for NNX to ensure consistent layer instances.""" # Create a mapping from operation id to layer instance self._nnx_layer_mapping = {} - + # Store layers as direct attributes for NNX traversal for i, layer in enumerate(self._layers): if isinstance(layer, Layer): @@ -165,7 +165,7 @@ def _setup_nnx_layer_mapping(self): setattr(self, attr_name, layer) # Map the operation id to this layer instance self._nnx_layer_mapping[id(layer)] = layer - + # Also map any operations in the graph to ensure consistency for operation in self._operations: if isinstance(operation, Layer): @@ -207,7 +207,7 @@ def call(self, inputs, training=None, mask=None, **kwargs): for x, mask in zip(inputs, masks): if mask is not None: backend.set_keras_mask(x, mask) - + # Use NNX-compatible execution when NNX is enabled if is_nnx_enabled(): outputs = self._run_through_graph_nnx_compatible( @@ -225,7 +225,9 @@ def call(self, inputs, training=None, mask=None, **kwargs): ) return unpack_singleton(outputs) - def _run_through_graph_nnx_compatible(self, inputs, operation_fn, call_fn=None): + def _run_through_graph_nnx_compatible( + self, inputs, operation_fn, call_fn=None + ): """NNX-compatible graph execution that ensures consistent layer instances.""" inputs = tree.flatten(inputs) @@ -248,12 +250,15 @@ def _run_through_graph_nnx_compatible(self, inputs, operation_fn, call_fn=None): continue # Node is not computable, try skipping. args, kwargs = node.arguments.fill_in(tensor_dict) - + # Use the consistent layer instance for NNX compatibility operation = node.operation - if hasattr(self, '_nnx_layer_mapping') and id(operation) in self._nnx_layer_mapping: + if ( + hasattr(self, "_nnx_layer_mapping") + and id(operation) in self._nnx_layer_mapping + ): operation = self._nnx_layer_mapping[id(operation)] - + op = operation_fn(operation) if call_fn is not None: outputs = call_fn(op, *args, **kwargs) diff --git a/shell/lint.sh b/shell/lint.sh new file mode 100644 index 000000000000..e69de29bb2d1 From 12eb2a0940f75357eb2e797479ef6867b274395c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 9 Jul 2025 18:53:35 +0000 Subject: [PATCH 078/103] code reformat --- keras/src/models/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index e53c3ce0de0f..ddb1887b28dd 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -228,7 +228,7 @@ def call(self, inputs, training=None, mask=None, **kwargs): def _run_through_graph_nnx_compatible( self, inputs, operation_fn, call_fn=None ): - """NNX-compatible graph execution that ensures consistent layer instances.""" + """NNX-compatible graph execution ensures consistent layer instances.""" inputs = tree.flatten(inputs) # Dictionary mapping reference tensors to computed tensors. 
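On the layer mapping PATCH 077 keeps in `functional.py`: per the patch's own comments, layers are pinned as direct attributes so that flax's nnx traversal, which walks object attributes, sees the very instances the graph executes. A reduced sketch assuming the flax nnx API; `Holder` is illustrative, not Keras code:

    from flax import nnx

    class Holder(nnx.Module):
        def __init__(self, layers):
            self._mapping = {}
            for i, layer in enumerate(layers):
                # A direct attribute makes the layer visible to nnx.
                setattr(self, f"_layer_{i}_{layer.name}", layer)
                self._mapping[id(layer)] = layer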
From a260cb48d9d10faf50768a74fb1aaebf47dca01e Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Fri, 11 Jul 2025 23:46:44 +0000 Subject: [PATCH 079/103] address gemini comments --- keras/src/backend/config.py | 4 ++-- keras/src/layers/layer.py | 8 ++++++-- keras/src/models/functional.py | 7 ------- 3 files changed, 8 insertions(+), 11 deletions(-) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index b33607bc0ff7..f57b2b2861f9 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -443,9 +443,9 @@ def max_steps_per_epoch(): if "KERAS_NNX_ENABLED" in os.environ: env_val = os.environ["KERAS_NNX_ENABLED"].lower() - if env_val == "true": + if env_val: _NNX_ENABLED = True - elif env_val == "false": + else: _NNX_ENABLED = False set_nnx_enabled(_NNX_ENABLED) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 6acedc355fc8..792e287346a6 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1675,8 +1675,12 @@ def _open_name_scope(self): if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()): try: self._parent_path = current_path() - except Exception: - pass + except Exception as e: + warnings.warn( + "Could not set `_parent_path` in " + f"`_open_name_scope` for layer {self.name}. " + f"Error: {e}" + ) return backend.name_scope(self.name, caller=self) def rematerialized_call(self, layer_call, *args, **kwargs): diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index ddb1887b28dd..b49e8756db50 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -166,13 +166,6 @@ def _setup_nnx_layer_mapping(self): # Map the operation id to this layer instance self._nnx_layer_mapping[id(layer)] = layer - # Also map any operations in the graph to ensure consistency - for operation in self._operations: - if isinstance(operation, Layer): - # Ensure the graph operation points to the same instance - if id(operation) not in self._nnx_layer_mapping: - self._nnx_layer_mapping[id(operation)] = operation - def _lock_state(self): # Unlike other layers, we allow Functional state to be mutable after # build. E.g. 
to attach a layer to a model that is not part of the From c79a57f708f99dabc343f6dfc319e91b0719befc Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 19:52:46 +0000 Subject: [PATCH 080/103] address latest comments --- .github/workflows/config/nnx/keras.json | 7 +++ .github/workflows/nnx-tests.yml | 61 +++++++++++++++++++++++++ integration_tests/import_test.py | 32 ++++++++----- keras/src/backend/config.py | 47 ++++++++++--------- 4 files changed, 111 insertions(+), 36 deletions(-) create mode 100644 .github/workflows/config/nnx/keras.json create mode 100644 .github/workflows/nnx-tests.yml diff --git a/.github/workflows/config/nnx/keras.json b/.github/workflows/config/nnx/keras.json new file mode 100644 index 000000000000..263fba3763cb --- /dev/null +++ b/.github/workflows/config/nnx/keras.json @@ -0,0 +1,7 @@ +{ + "floatx": "float32", + "epsilon": 1e-07, + "backend": "jax", + "image_data_format": "channels_last", + "nnx_enabled": true +} \ No newline at end of file diff --git a/.github/workflows/nnx-tests.yml b/.github/workflows/nnx-tests.yml new file mode 100644 index 000000000000..cbdab3372ec2 --- /dev/null +++ b/.github/workflows/nnx-tests.yml @@ -0,0 +1,61 @@ +name: NNX Tests + +on: + push: + branches: [ master ] + pull_request: + release: + types: [created] + +permissions: + contents: read + +jobs: + nnx-tests: + name: Run NNX tests + runs-on: ubuntu-latest + env: + PYTHON: '3.10' + KERAS_HOME: .github/workflows/config/nnx + KERAS_BACKEND: jax + KERAS_NNX_ENABLED: true + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip setuptools + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + + - name: pip cache + uses: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} + + - name: Install dependencies + run: | + pip install -r requirements.txt --progress-bar off --upgrade + pip install --upgrade git+https://github.com/divyashreepathihalli/flax.git@use-id + pip uninstall -y keras keras-nightly + pip install -e "." --progress-bar off --upgrade + + - name: Test basic flow with NNX + run: | + python integration_tests/import_test.py + python integration_tests/basic_full_flow.py + + - name: Codecov NNX + uses: codecov/codecov-action@v5 + with: + env_vars: PYTHON,KERAS_HOME,KERAS_BACKEND,KERAS_NNX_ENABLED + flags: keras-nnx + files: nnx-coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py index c374e386483d..b9e660d95df5 100644 --- a/integration_tests/import_test.py +++ b/integration_tests/import_test.py @@ -3,6 +3,7 @@ import subprocess from keras.src import backend +from keras.src.backend import config # For torch, use index url to avoid installing nvidia drivers for the test. 
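A caution on the `KERAS_NNX_ENABLED` parsing PATCH 079 introduced (and PATCH 080 keeps): `env_val` is a lowercased string, so `if env_val:` treats any non-empty value, including "false" and "0", as true. A defensive parse, as a sketch; `_parse_bool_env` is an invented helper, not Keras code:

    import os

    def _parse_bool_env(name, default=False):
        raw = os.environ.get(name)
        if raw is None:
            return default
        # Accept common spellings; anything else falls back to default.
        return raw.strip().lower() in ("1", "true", "yes", "on")

    _NNX_ENABLED = _parse_bool_env("KERAS_NNX_ENABLED")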
BACKEND_REQ = { @@ -11,9 +12,7 @@ "torch", "--extra-index-url https://download.pytorch.org/whl/cpu ", ), - # please update the jax version here if jax version is updated in - # requirements file - "jax": ("jax[cpu]==0.5.1 flax>=0.10.1", ""), + "jax": ("jax[cpu]", ""), "openvino": ("openvino", ""), } @@ -57,16 +56,25 @@ def manage_venv_installs(whl_path): "pip install " + backend_extra_url + backend_pkg, "pip install -r requirements-common.txt", "pip install pytest", - # Ensure other backends are uninstalled - "pip uninstall -y " - + BACKEND_REQ[other_backends[0]][0] - + " " - + BACKEND_REQ[other_backends[1]][0] - + " " - + BACKEND_REQ[other_backends[2]][0], - # Install `.whl` package - "pip install " + whl_path, ] + + # Install flax for JAX when NNX is enabled + if backend.backend() == "jax" and config.is_nnx_enabled(): + install_setup.append("pip install flax>=0.10.1") + + install_setup.extend( + [ + # Ensure other backends are uninstalled + "pip uninstall -y " + + BACKEND_REQ[other_backends[0]][0] + + " " + + BACKEND_REQ[other_backends[1]][0] + + " " + + BACKEND_REQ[other_backends[2]][0], + # Install `.whl` package + "pip install " + whl_path, + ] + ) run_commands_venv(install_setup) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index f57b2b2861f9..30760cba6af7 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -254,8 +254,7 @@ def set_nnx_enabled(value): from flax import nnx # noqa F401 except ImportError: raise ImportError( - "To use the NNX backend, you must install `flax`." - "Try: `pip install flax`" + "To use NNX with the JAX backend, you must install `flax`." ) global_state.set_global_attribute("nnx_enabled", bool(value)) @@ -292,6 +291,28 @@ def keras_home(): # Attempt to read Keras config file. _config_path = os.path.expanduser(os.path.join(_KERAS_DIR, "keras.json")) +# Save config file, if possible. +if not os.path.exists(_KERAS_DIR): + try: + os.makedirs(_KERAS_DIR) + except OSError: + # Except permission denied. + pass + +if not os.path.exists(_config_path): + _config_to_save = { + "floatx": _FLOATX, + "epsilon": _EPSILON, + "backend": _BACKEND, + "image_data_format": _IMAGE_DATA_FORMAT, + } + try: + with open(_config_path, "w") as f: + f.write(json.dumps(_config_to_save, indent=4)) + except IOError: + # Except permission denied. + pass + if os.path.exists(_config_path): try: with open(_config_path) as f: @@ -419,28 +440,6 @@ def max_steps_per_epoch(): return _MAX_STEPS_PER_EPOCH -if not os.path.exists(_KERAS_DIR): - try: - os.makedirs(_KERAS_DIR) - except OSError: - # Except permission denied and potential race conditions - pass - -if not os.path.exists(_config_path): - _config_to_save = { - "floatx": floatx(), - "epsilon": epsilon(), - "backend": _BACKEND, # Use the final _BACKEND value - "image_data_format": image_data_format(), - "nnx_enabled": _NNX_ENABLED, - } - try: - with open(_config_path, "w") as f: - f.write(json.dumps(_config_to_save, indent=4)) - except IOError: - # Except permission denied. 
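PATCH 080's config.py reshuffle above follows a write-defaults-then-read pattern: create `keras.json` from module defaults only when it is missing, then load whatever is on disk. In miniature, with invented values standing in for Keras' constants:

    import json
    import os

    path = os.path.expanduser(os.path.join("~", ".keras", "keras.json"))
    defaults = {"floatx": "float32", "backend": "jax", "nnx_enabled": False}
    if not os.path.exists(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w") as f:
            json.dump(defaults, f, indent=4)
    with open(path) as f:
        config = json.load(f)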
- pass - if "KERAS_NNX_ENABLED" in os.environ: env_val = os.environ["KERAS_NNX_ENABLED"].lower() if env_val: From 57b42cb60ba543fc39494f7c79dd272d908e5f0c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 20:03:17 +0000 Subject: [PATCH 081/103] remove hash function --- keras/src/backend/jax/core.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py index ccef9e986699..9e9595d373c8 100644 --- a/keras/src/backend/jax/core.py +++ b/keras/src/backend/jax/core.py @@ -231,11 +231,6 @@ def value(self): ) return self._maybe_autocast(current_value) - # Todo: NNX has agreed to fix it on their end. I will remove it once - # that is done - def __hash__(self): - return id(self) - _JAX_VARIABLE_TYPE = NnxVariable From 567f1204efae9d098da04f1977f1c4343424ecbb Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 20:14:55 +0000 Subject: [PATCH 082/103] update tests name --- .github/workflows/nnx-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/nnx-tests.yml b/.github/workflows/nnx-tests.yml index cbdab3372ec2..e2025abf8ee8 100644 --- a/.github/workflows/nnx-tests.yml +++ b/.github/workflows/nnx-tests.yml @@ -1,4 +1,4 @@ -name: NNX Tests +name: Tests on: push: @@ -12,7 +12,7 @@ permissions: jobs: nnx-tests: - name: Run NNX tests + name: Run jax-nnx tests runs-on: ubuntu-latest env: PYTHON: '3.10' From 46db09bf441f4a7c73de48dafa9906ae5e669275 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 22:03:19 +0000 Subject: [PATCH 083/103] address latest comments --- integration_tests/basic_full_flow.py | 47 +++++++++++ keras/src/backend/config.py | 6 +- keras/src/layers/layer.py | 9 +- keras/src/models/functional.py | 119 +++++++++++---------------- shell/lint.sh | 0 5 files changed, 103 insertions(+), 78 deletions(-) delete mode 100644 shell/lint.sh diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py index 66440ed6ab37..92ca3c3f1a07 100644 --- a/integration_tests/basic_full_flow.py +++ b/integration_tests/basic_full_flow.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest @@ -52,3 +54,48 @@ def test_basic_fit_no_training(self): x = np.random.random((128, 4)) model(x) model.predict(x) + + @pytest.mark.skipif( + os.environ.get("KERAS_NNX_ENABLED") != "true", + reason="Test only runs with NNX enabled", + ) + def test_bare_ops_functional(self): + """Test that functional models work correctly with bare ops.""" + # Create input + inputs = keras.Input(shape=(10,)) + + # Add a layer + x = layers.Dense(5, activation="relu")(inputs) + + # Add a bare op (not a layer) + x = keras.ops.add(x, 2.0) + + # Add another layer + outputs = layers.Dense(1)(x) + + # Create functional model + model = keras.Model(inputs=inputs, outputs=outputs) + + # Test forward pass + test_input = np.random.random((2, 10)) + output = model(test_input) + + # Verify output shape and values + self.assertEqual(output.shape, (2, 1)) + self.assertTrue(np.all(np.isfinite(output))) + + # Test with multiple bare ops + inputs2 = keras.Input(shape=(5,)) + x2 = layers.Dense(3, activation="relu")(inputs2) + x2 = keras.ops.add(x2, 1.0) + x2 = keras.ops.multiply(x2, 2.0) + x2 = keras.ops.subtract(x2, 0.5) + outputs2 = layers.Dense(1)(x2) + + model2 = keras.Model(inputs=inputs2, outputs=outputs2) + test_input2 = np.random.random((3, 5)) + output2 = model2(test_input2) + + # Verify output shape and values + self.assertEqual(output2.shape, 
(3, 1)) + self.assertTrue(np.all(np.isfinite(output2))) diff --git a/keras/src/backend/config.py b/keras/src/backend/config.py index 30760cba6af7..516de51dc658 100644 --- a/keras/src/backend/config.py +++ b/keras/src/backend/config.py @@ -305,6 +305,7 @@ def keras_home(): "epsilon": _EPSILON, "backend": _BACKEND, "image_data_format": _IMAGE_DATA_FORMAT, + "nnx_enabled": is_nnx_enabled(), } try: with open(_config_path, "w") as f: @@ -327,10 +328,9 @@ def keras_home(): _image_data_format = _config.get("image_data_format", image_data_format()) assert _image_data_format in {"channels_last", "channels_first"} _nnx_enabled_config = _config.get("nnx_enabled", _NNX_ENABLED) - if not isinstance(_nnx_enabled_config, bool): - _NNX_ENABLED = str(_nnx_enabled_config).lower() == "true" - else: + if isinstance(_nnx_enabled_config, bool): _NNX_ENABLED = _nnx_enabled_config + # else: ignore non-bool values for nnx_enabled # Apply basic configs that don't cause circular import set_floatx(_floatx) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 1982ea18ae09..70a15b6236ba 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1680,11 +1680,12 @@ def _open_name_scope(self): if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()): try: self._parent_path = current_path() - except Exception as e: + except Exception: warnings.warn( - "Could not set `_parent_path` in " - f"`_open_name_scope` for layer {self.name}. " - f"Error: {e}" + f"Layer '{self.name}' encountered an issue during " + "model construction. If you're experiencing unexpected " + "behavior, try calling your model on a small test " + "input first to ensure proper initialization." ) return backend.name_scope(self.name, caller=self) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index b49e8756db50..52eb9d291fa2 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -141,9 +141,9 @@ def __init__(self, inputs, outputs, name=None, **kwargs): self._layers = self.layers - # Special handling for NNX to ensure consistent layer instance usage + # Special handling for NNX to ensure consistent operation instance usage if is_nnx_enabled(): - self._setup_nnx_layer_mapping() + self._setup_nnx_op_mapping() self.build(None) # We will convert directly (to the correct dtype per input). 
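What the bare-ops test added in PATCH 083 above exercises, in brief: `keras.ops` calls on symbolic tensors create graph nodes just like layers do, so functional execution must handle operations that are not `Layer` instances. A condensed, runnable version of the same idea:

    import numpy as np
    import keras
    from keras import layers

    inputs = keras.Input(shape=(4,))
    x = layers.Dense(3, activation="relu")(inputs)
    x = keras.ops.multiply(x, 2.0)  # a bare op node, not a Layer
    outputs = layers.Dense(1)(x)
    model = keras.Model(inputs, outputs)
    print(model(np.random.random((2, 4))).shape)  # (2, 1)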
@@ -152,19 +152,19 @@ def __init__(self, inputs, outputs, name=None, **kwargs): output_layers = [x._keras_history[0] for x in self.outputs] self.output_names = [x.name for x in output_layers] - def _setup_nnx_layer_mapping(self): - """Setup layer mapping for NNX to ensure consistent layer instances.""" - # Create a mapping from operation id to layer instance - self._nnx_layer_mapping = {} + def _setup_nnx_op_mapping(self): + """Setup operation mapping for NNX""" + # Create a mapping from operation id to operation instance + self._nnx_op_mapping = {} - # Store layers as direct attributes for NNX traversal - for i, layer in enumerate(self._layers): - if isinstance(layer, Layer): - # Store layer as direct attribute with unique name - attr_name = f"_layer_{i}_{layer.name}" - setattr(self, attr_name, layer) - # Map the operation id to this layer instance - self._nnx_layer_mapping[id(layer)] = layer + # Store operations as direct attributes for NNX traversal + for i, operation in enumerate(self._operations): + if isinstance(operation, Layer): + # Store operation as direct attribute with unique name + attr_name = f"_layer_{i}_{operation.name}" + setattr(self, attr_name, operation) + # Map the operation id to this operation instance + self._nnx_op_mapping[id(operation)] = operation def _lock_state(self): # Unlike other layers, we allow Functional state to be mutable after @@ -190,84 +190,61 @@ def layers(self, _): "Please use another name." ) - def call(self, inputs, training=None, mask=None, **kwargs): - # Add support for training, masking - inputs = self._standardize_inputs(inputs) - if mask is None: - masks = [None] * len(inputs) - else: - masks = tree.flatten(mask) - for x, mask in zip(inputs, masks): - if mask is not None: - backend.set_keras_mask(x, mask) - - # Use NNX-compatible execution when NNX is enabled - if is_nnx_enabled(): - outputs = self._run_through_graph_nnx_compatible( - inputs, - operation_fn=lambda op: operation_fn( - op, training=training, **kwargs - ), - ) - else: - outputs = self._run_through_graph( - inputs, - operation_fn=lambda op: operation_fn( - op, training=training, **kwargs - ), - ) - return unpack_singleton(outputs) + def _get_operation_for_node(self, node): + operation = node.operation + if hasattr(self, "_nnx_op_mapping") and id(operation) in getattr( + self, "_nnx_op_mapping", {} + ): + return self._nnx_op_mapping[id(operation)] + return operation - def _run_through_graph_nnx_compatible( - self, inputs, operation_fn, call_fn=None - ): - """NNX-compatible graph execution ensures consistent layer instances.""" + def _run_through_graph(self, inputs, operation_fn, call_fn=None): + """Unified graph execution that supports NNX layer mapping.""" inputs = tree.flatten(inputs) - - # Dictionary mapping reference tensors to computed tensors. tensor_dict = {} for x, y in zip(self.inputs, inputs): tensor_dict[id(x)] = y - nodes_by_depth = self._nodes_by_depth depth_keys = list(nodes_by_depth.keys()) depth_keys.sort(reverse=True) - for depth in depth_keys: nodes = nodes_by_depth[depth] for node in nodes: if not node.operation or node.is_input: - continue # Input tensors already exist. - + continue if any(id(x) not in tensor_dict for x in node.input_tensors): - continue # Node is not computable, try skipping. 
- + continue args, kwargs = node.arguments.fill_in(tensor_dict) - - # Use the consistent layer instance for NNX compatibility - operation = node.operation - if ( - hasattr(self, "_nnx_layer_mapping") - and id(operation) in self._nnx_layer_mapping - ): - operation = self._nnx_layer_mapping[id(operation)] - + operation = self._get_operation_for_node(node) op = operation_fn(operation) if call_fn is not None: outputs = call_fn(op, *args, **kwargs) else: outputs = op(*args, **kwargs) - - # Update tensor_dict. for x, y in zip(node.outputs, tree.flatten(outputs)): tensor_dict[id(x)] = y - output_tensors = [] for x in self.outputs: output_tensors.append(tensor_dict[id(x)]) - return tree.pack_sequence_as(self._outputs_struct, output_tensors) + def call(self, inputs, training=None, mask=None, **kwargs): + inputs = self._standardize_inputs(inputs) + if mask is None: + masks = [None] * len(inputs) + else: + masks = tree.flatten(mask) + for x, mask in zip(inputs, masks): + if mask is not None: + backend.set_keras_mask(x, mask) + outputs = self._run_through_graph( + inputs, + operation_fn=lambda op: operation_fn( + op, training=training, **kwargs + ), + ) + return unpack_singleton(outputs) + def compute_output_spec(self, inputs, training=None, mask=None): # From Function return super().compute_output_spec(inputs) @@ -473,7 +450,7 @@ def get_config(self): # the author of the subclassed network). return Model.get_config(self) - config = { + cfg = { "name": self.name, "trainable": self.trainable, } @@ -520,7 +497,7 @@ def get_config(self): layer_config["name"] = operation.name layer_config["inbound_nodes"] = filtered_inbound_nodes layer_configs.append(layer_config) - config["layers"] = layer_configs + cfg["layers"] = layer_configs # Gather info about inputs and outputs. def get_tensor_config(tensor): @@ -535,9 +512,9 @@ def get_tensor_config(tensor): def map_tensors(tensors): return tree.map_structure(get_tensor_config, tensors) - config["input_layers"] = map_tensors(self._inputs_struct) - config["output_layers"] = map_tensors(self._outputs_struct) - return copy.deepcopy(config) + cfg["input_layers"] = map_tensors(self._inputs_struct) + cfg["output_layers"] = map_tensors(self._outputs_struct) + return copy.deepcopy(cfg) def functional_from_config(cls, config, custom_objects=None): diff --git a/shell/lint.sh b/shell/lint.sh deleted file mode 100644 index e69de29bb2d1..000000000000 From c9b87b174967de5f87224774733febeae25f853c Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 22:42:36 +0000 Subject: [PATCH 084/103] address review comments --- .github/workflows/actions.yml | 22 ++++++++- .github/workflows/config/jax/keras.json | 3 +- .github/workflows/nnx-tests.yml | 61 ------------------------- requirements.txt | 2 +- 4 files changed, 23 insertions(+), 65 deletions(-) delete mode 100644 .github/workflows/nnx-tests.yml diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index b9e785dfc949..8a2428fe88ac 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -1,5 +1,8 @@ name: Tests +# TODO: Consider enabling all tests (pytest, applications, etc.) 
with NNX in the future +# Currently only basic flow tests run with NNX enabled + on: push: branches: [ master ] @@ -17,11 +20,18 @@ jobs: matrix: python-version: ['3.10'] backend: [tensorflow, jax, torch, numpy, openvino] + nnx_enabled: [false] + include: + - python-version: '3.10' + backend: jax + nnx_enabled: true name: Run tests runs-on: ubuntu-latest env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} + KERAS_BACKEND: ${{ matrix.backend }} + KERAS_NNX_ENABLED: ${{ matrix.nnx_enabled }} steps: - uses: actions/checkout@v4 - name: Check for changes in keras/src/applications @@ -48,6 +58,9 @@ jobs: - name: Install dependencies run: | pip install -r requirements.txt --progress-bar off --upgrade + if [ "${{ matrix.nnx_enabled }}" == "true" ]; then + pip install --upgrade git+https://github.com/divyashreepathihalli/flax.git@use-id + fi pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade - name: Test applications with pytest @@ -73,6 +86,11 @@ jobs: if: ${{ matrix.backend == 'jax'}} run: | python integration_tests/jax_custom_fit_test.py + - name: Test basic flow with NNX + if: ${{ matrix.nnx_enabled == 'true'}} + run: | + python integration_tests/import_test.py + python integration_tests/basic_full_flow.py - name: Test TF-specific integrations if: ${{ matrix.backend == 'tensorflow'}} run: | @@ -96,8 +114,8 @@ jobs: - name: Codecov keras uses: codecov/codecov-action@v5 with: - env_vars: PYTHON,KERAS_HOME - flags: keras,keras-${{ matrix.backend }} + env_vars: PYTHON,KERAS_HOME,KERAS_BACKEND,KERAS_NNX_ENABLED + flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }} files: core-coverage.xml token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false diff --git a/.github/workflows/config/jax/keras.json b/.github/workflows/config/jax/keras.json index 38b3a3207673..e20cd4ea7bfe 100644 --- a/.github/workflows/config/jax/keras.json +++ b/.github/workflows/config/jax/keras.json @@ -2,5 +2,6 @@ "floatx": "float32", "epsilon": 1e-07, "backend": "jax", - "image_data_format": "channels_last" + "image_data_format": "channels_last", + "nnx_enabled": false } diff --git a/.github/workflows/nnx-tests.yml b/.github/workflows/nnx-tests.yml deleted file mode 100644 index e2025abf8ee8..000000000000 --- a/.github/workflows/nnx-tests.yml +++ /dev/null @@ -1,61 +0,0 @@ -name: Tests - -on: - push: - branches: [ master ] - pull_request: - release: - types: [created] - -permissions: - contents: read - -jobs: - nnx-tests: - name: Run jax-nnx tests - runs-on: ubuntu-latest - env: - PYTHON: '3.10' - KERAS_HOME: .github/workflows/config/nnx - KERAS_BACKEND: jax - KERAS_NNX_ENABLED: true - steps: - - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Get pip cache dir - id: pip-cache - run: | - python -m pip install --upgrade pip setuptools - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - - name: pip cache - uses: actions/cache@v4 - with: - path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }} - - - name: Install dependencies - run: | - pip install -r requirements.txt --progress-bar off --upgrade - pip install --upgrade git+https://github.com/divyashreepathihalli/flax.git@use-id - pip uninstall -y keras keras-nightly - pip install -e "." 
--progress-bar off --upgrade - - - name: Test basic flow with NNX - run: | - python integration_tests/import_test.py - python integration_tests/basic_full_flow.py - - - name: Codecov NNX - uses: codecov/codecov-action@v5 - with: - env_vars: PYTHON,KERAS_HOME,KERAS_BACKEND,KERAS_NNX_ENABLED - flags: keras-nnx - files: nnx-coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false diff --git a/requirements.txt b/requirements.txt index 730f1fb2601c..e5a44501e6b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,6 @@ torch-xla==2.6.0;sys_platform != 'darwin' # Pinned to 0.5.0 on CPU. JAX 0.5.1 requires Tensorflow 2.19 for saved_model_test. # Note that we test against the latest JAX on GPU. jax[cpu]==0.5.0 -flax>=0.10.1 +flax # Common deps. -r requirements-common.txt From d4b5afa74589e57536dc2e48aebb2f3505578c35 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 22:53:24 +0000 Subject: [PATCH 085/103] fix actions --- .github/workflows/actions.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 8a2428fe88ac..c23f386dd7b3 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -25,13 +25,13 @@ jobs: - python-version: '3.10' backend: jax nnx_enabled: true - name: Run tests + name: ${{ matrix.backend == 'jax' && format('Run tests ({0}, {1}, nnx_enabled = {2})', matrix.python-version, matrix.backend, matrix.nnx_enabled) || format('Run tests ({0}, {1})', matrix.python-version, matrix.backend) }} runs-on: ubuntu-latest env: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} KERAS_BACKEND: ${{ matrix.backend }} - KERAS_NNX_ENABLED: ${{ matrix.nnx_enabled }} + KERAS_NNX_ENABLED: ${{ matrix.backend == 'jax' && matrix.nnx_enabled || 'false' }} steps: - uses: actions/checkout@v4 - name: Check for changes in keras/src/applications From 34fbeedb14b2030191e68328b895f138e221db68 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 23:05:53 +0000 Subject: [PATCH 086/103] fix actions --- .github/workflows/actions.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index c23f386dd7b3..0c7fd6e7ec8d 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -31,7 +31,6 @@ jobs: PYTHON: ${{ matrix.python-version }} KERAS_HOME: .github/workflows/config/${{ matrix.backend }} KERAS_BACKEND: ${{ matrix.backend }} - KERAS_NNX_ENABLED: ${{ matrix.backend == 'jax' && matrix.nnx_enabled || 'false' }} steps: - uses: actions/checkout@v4 - name: Check for changes in keras/src/applications @@ -88,6 +87,8 @@ jobs: python integration_tests/jax_custom_fit_test.py - name: Test basic flow with NNX if: ${{ matrix.nnx_enabled == 'true'}} + env: + KERAS_NNX_ENABLED: true run: | python integration_tests/import_test.py python integration_tests/basic_full_flow.py From 3683dc8f09f3e370657f688b70a01c511c922958 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 23:15:03 +0000 Subject: [PATCH 087/103] skipt other tests in nnx backend --- .github/workflows/actions.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 0c7fd6e7ec8d..07448b49ed83 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -82,7 +82,7 @@ jobs: python integration_tests/import_test.py python 
integration_tests/numerical_test.py - name: Test JAX-specific integrations - if: ${{ matrix.backend == 'jax'}} + if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled != 'true'}} run: | python integration_tests/jax_custom_fit_test.py - name: Test basic flow with NNX @@ -103,6 +103,7 @@ jobs: pytest integration_tests/torch_workflow_test.py python integration_tests/torch_custom_fit_test.py - name: Test with pytest + if: ${{ matrix.nnx_enabled != 'true'}} run: | if [ "${{ matrix.backend }}" == "openvino" ]; then IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt" From 53240f96fbe64aa00fdec6d8b766a4dfc452ba89 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 23:46:24 +0000 Subject: [PATCH 088/103] revert changes to basic flow --- integration_tests/basic_full_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py index 92ca3c3f1a07..6985533b1f01 100644 --- a/integration_tests/basic_full_flow.py +++ b/integration_tests/basic_full_flow.py @@ -52,8 +52,8 @@ def test_basic_fit(self): def test_basic_fit_no_training(self): model = MyModel(hidden_dim=2, output_dim=1) x = np.random.random((128, 4)) - model(x) model.predict(x) + model(x) @pytest.mark.skipif( os.environ.get("KERAS_NNX_ENABLED") != "true", From 896ffa04c9a181f8de42e6090d5a59ed6c91876a Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Mon, 14 Jul 2025 23:49:40 +0000 Subject: [PATCH 089/103] point installation to official JAX code --- .github/workflows/actions.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index 07448b49ed83..fee8785ec2bc 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -58,7 +58,7 @@ jobs: run: | pip install -r requirements.txt --progress-bar off --upgrade if [ "${{ matrix.nnx_enabled }}" == "true" ]; then - pip install --upgrade git+https://github.com/divyashreepathihalli/flax.git@use-id + pip install --upgrade git+https://github.com/google/flax.git fi pip uninstall -y keras keras-nightly pip install -e "." --progress-bar off --upgrade From 587bae78bf0a9e3de0c72af509d87b18d9bce76b Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 15 Jul 2025 00:27:42 +0000 Subject: [PATCH 090/103] fix actions --- .github/workflows/actions.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/actions.yml b/.github/workflows/actions.yml index fee8785ec2bc..49074e62aa1a 100644 --- a/.github/workflows/actions.yml +++ b/.github/workflows/actions.yml @@ -63,12 +63,12 @@ jobs: pip uninstall -y keras keras-nightly pip install -e "." 
--progress-bar off --upgrade - name: Test applications with pytest - if: ${{ steps.filter.outputs.applications == 'true' }} + if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }} run: | pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml coverage xml --include='keras/src/applications/*' -o apps-coverage.xml - name: Codecov keras.applications - if: ${{ steps.filter.outputs.applications == 'true' }} + if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }} uses: codecov/codecov-action@v5 with: env_vars: PYTHON,KERAS_HOME @@ -77,16 +77,16 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - name: Test integrations - if: ${{ matrix.backend != 'numpy'}} + if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }} run: | python integration_tests/import_test.py python integration_tests/numerical_test.py - name: Test JAX-specific integrations - if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled != 'true'}} + if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }} run: | python integration_tests/jax_custom_fit_test.py - name: Test basic flow with NNX - if: ${{ matrix.nnx_enabled == 'true'}} + if: ${{ matrix.nnx_enabled == true }} env: KERAS_NNX_ENABLED: true run: | @@ -103,7 +103,7 @@ jobs: pytest integration_tests/torch_workflow_test.py python integration_tests/torch_custom_fit_test.py - name: Test with pytest - if: ${{ matrix.nnx_enabled != 'true'}} + if: ${{ matrix.nnx_enabled == false }} run: | if [ "${{ matrix.backend }}" == "openvino" ]; then IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt" @@ -114,6 +114,7 @@ jobs: pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml - name: Codecov keras + if: ${{ matrix.nnx_enabled == false }} uses: codecov/codecov-action@v5 with: env_vars: PYTHON,KERAS_HOME,KERAS_BACKEND,KERAS_NNX_ENABLED From 3b4713cad8a7203f5192df4a0465043275d8680a Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 15 Jul 2025 01:02:27 +0000 Subject: [PATCH 091/103] revert basic flow test --- integration_tests/basic_full_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py index 6985533b1f01..92ca3c3f1a07 100644 --- a/integration_tests/basic_full_flow.py +++ b/integration_tests/basic_full_flow.py @@ -52,8 +52,8 @@ def test_basic_fit(self): def test_basic_fit_no_training(self): model = MyModel(hidden_dim=2, output_dim=1) x = np.random.random((128, 4)) - model.predict(x) model(x) + model.predict(x) @pytest.mark.skipif( os.environ.get("KERAS_NNX_ENABLED") != "true", From 772929c78b34f87429987dbf37253530e30f2641 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Tue, 15 Jul 2025 22:18:31 +0000 Subject: [PATCH 092/103] move logic out of functional and into function class --- keras/src/models/functional.py | 43 ---------------------------------- keras/src/ops/function.py | 33 +++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 52eb9d291fa2..a33ddf5b41d3 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -7,7 +7,6 @@ from keras.src import ops from keras.src import tree from keras.src.backend.common import global_state -from keras.src.backend.config import is_nnx_enabled 
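Two notes on this stretch. First, PATCH 090's condition changes above are the substantive fix: matrix values such as `nnx_enabled: [false]` are booleans in GitHub Actions expressions, so the earlier string comparisons (`matrix.nnx_enabled == 'true'`) do not match boolean values; comparing against the booleans `true`/`false` makes the step guards actually take effect. Second, PATCH 092 relocates the NNX instance bookkeeping from `Functional` down to `ops.Function`, so every function-backed graph benefits. The identity mapping at its core, with a stand-in class:

    class Fn:  # stand-in for keras.src.ops.function.Function
        def __init__(self, operations):
            self._nnx_op_mapping = {id(op): op for op in operations}

        def resolve(self, op):
            # Return the canonical instance recorded for this operation.
            return self._nnx_op_mapping.get(id(op), op)

    double = lambda x: 2 * x
    fn = Fn([double])
    assert fn.resolve(double) is double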
from keras.src.layers.core.input_layer import Input from keras.src.layers.core.input_layer import InputLayer from keras.src.layers.input_spec import InputSpec @@ -141,10 +140,6 @@ def __init__(self, inputs, outputs, name=None, **kwargs): self._layers = self.layers - # Special handling for NNX to ensure consistent operation instance usage - if is_nnx_enabled(): - self._setup_nnx_op_mapping() - self.build(None) # We will convert directly (to the correct dtype per input). self._convert_input_args = False @@ -190,44 +185,6 @@ def layers(self, _): "Please use another name." ) - def _get_operation_for_node(self, node): - operation = node.operation - if hasattr(self, "_nnx_op_mapping") and id(operation) in getattr( - self, "_nnx_op_mapping", {} - ): - return self._nnx_op_mapping[id(operation)] - return operation - - def _run_through_graph(self, inputs, operation_fn, call_fn=None): - """Unified graph execution that supports NNX layer mapping.""" - inputs = tree.flatten(inputs) - tensor_dict = {} - for x, y in zip(self.inputs, inputs): - tensor_dict[id(x)] = y - nodes_by_depth = self._nodes_by_depth - depth_keys = list(nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - for depth in depth_keys: - nodes = nodes_by_depth[depth] - for node in nodes: - if not node.operation or node.is_input: - continue - if any(id(x) not in tensor_dict for x in node.input_tensors): - continue - args, kwargs = node.arguments.fill_in(tensor_dict) - operation = self._get_operation_for_node(node) - op = operation_fn(operation) - if call_fn is not None: - outputs = call_fn(op, *args, **kwargs) - else: - outputs = op(*args, **kwargs) - for x, y in zip(node.outputs, tree.flatten(outputs)): - tensor_dict[id(x)] = y - output_tensors = [] - for x in self.outputs: - output_tensors.append(tensor_dict[id(x)]) - return tree.pack_sequence_as(self._outputs_struct, output_tensors) - def call(self, inputs, training=None, mask=None, **kwargs): inputs = self._standardize_inputs(inputs) if mask is None: diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py index 18088cd3f5d9..d7663e6415e8 100644 --- a/keras/src/ops/function.py +++ b/keras/src/ops/function.py @@ -4,6 +4,7 @@ from keras.src.api_export import keras_export from keras.src.backend import KerasTensor from keras.src.backend.config import backend +from keras.src.backend.config import is_nnx_enabled from keras.src.ops.operation import Operation @@ -88,6 +89,10 @@ def __init__(self, inputs, outputs, name=None): ): raise ValueError("`inputs` not connected to `outputs`") + # Special handling for NNX to ensure consistent operation instance usage + if is_nnx_enabled(): + self._setup_nnx_op_mapping() + @property def operations(self): return self._operations[:] @@ -102,6 +107,28 @@ def outputs(self): """Flat list of the symbolic outputs of the Function.""" return self._outputs + def _setup_nnx_op_mapping(self): + """Setup operation mapping for NNX""" + # Create a mapping from operation id to operation instance + self._nnx_op_mapping = {} + + # Store operations as direct attributes for NNX traversal + for i, operation in enumerate(self._operations): + # Store operation as direct attribute with unique name + attr_name = f"_op_{i}_{operation.name}" + setattr(self, attr_name, operation) + # Map the operation id to this operation instance + self._nnx_op_mapping[id(operation)] = operation + + def _get_operation_for_node(self, node): + """Get the operation for a node, using NNX mapping if enabled.""" + operation = node.operation + if hasattr(self, "_nnx_op_mapping") and 
id(operation) in getattr( + self, "_nnx_op_mapping", {} + ): + return self._nnx_op_mapping[id(operation)] + return operation + def compute_output_spec(self, inputs): self._assert_input_compatibility(inputs) # Check if input shapes are identical to ref input shapes, @@ -170,10 +197,14 @@ def _run_through_graph(self, inputs, operation_fn, call_fn=None): continue # Node is not computable, try skipping. args, kwargs = node.arguments.fill_in(tensor_dict) - op = operation_fn(node.operation) if call_fn is not None: + # Use call_fn if provided (e.g., for symbolic execution) + op = operation_fn(node.operation) outputs = call_fn(op, *args, **kwargs) else: + # Use NNX operation mapping + operation = self._get_operation_for_node(node) + op = operation_fn(operation) outputs = op(*args, **kwargs) # Update tensor_dict. From f84cc1e7649a1ecab382adfdb4d20a971fb4711a Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Thu, 17 Jul 2025 10:03:13 -0700 Subject: [PATCH 093/103] revert functional.py --- keras/src/models/functional.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index a33ddf5b41d3..72c781e83d81 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -139,7 +139,6 @@ def __init__(self, inputs, outputs, name=None, **kwargs): self.trainable = trainable self._layers = self.layers - self.build(None) # We will convert directly (to the correct dtype per input). self._convert_input_args = False @@ -147,20 +146,6 @@ def __init__(self, inputs, outputs, name=None, **kwargs): output_layers = [x._keras_history[0] for x in self.outputs] self.output_names = [x.name for x in output_layers] - def _setup_nnx_op_mapping(self): - """Setup operation mapping for NNX""" - # Create a mapping from operation id to operation instance - self._nnx_op_mapping = {} - - # Store operations as direct attributes for NNX traversal - for i, operation in enumerate(self._operations): - if isinstance(operation, Layer): - # Store operation as direct attribute with unique name - attr_name = f"_layer_{i}_{operation.name}" - setattr(self, attr_name, operation) - # Map the operation id to this operation instance - self._nnx_op_mapping[id(operation)] = operation - def _lock_state(self): # Unlike other layers, we allow Functional state to be mutable after # build. E.g. to attach a layer to a model that is not part of the @@ -186,6 +171,7 @@ def layers(self, _): ) def call(self, inputs, training=None, mask=None, **kwargs): + # Add support for training, masking inputs = self._standardize_inputs(inputs) if mask is None: masks = [None] * len(inputs) @@ -407,7 +393,7 @@ def get_config(self): # the author of the subclassed network). return Model.get_config(self) - cfg = { + config = { "name": self.name, "trainable": self.trainable, } @@ -454,7 +440,7 @@ def get_config(self): layer_config["name"] = operation.name layer_config["inbound_nodes"] = filtered_inbound_nodes layer_configs.append(layer_config) - cfg["layers"] = layer_configs + config["layers"] = layer_configs # Gather info about inputs and outputs. 
From f84cc1e7649a1ecab382adfdb4d20a971fb4711a Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Thu, 17 Jul 2025 10:03:13 -0700
Subject: [PATCH 093/103] revert functional.py

---
 keras/src/models/functional.py | 26 ++++++--------------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py
index a33ddf5b41d3..72c781e83d81 100644
--- a/keras/src/models/functional.py
+++ b/keras/src/models/functional.py
@@ -139,7 +139,6 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         self.trainable = trainable
 
         self._layers = self.layers
-        self.build(None)
         # We will convert directly (to the correct dtype per input).
         self._convert_input_args = False
@@ -147,20 +146,6 @@ def __init__(self, inputs, outputs, name=None, **kwargs):
         output_layers = [x._keras_history[0] for x in self.outputs]
         self.output_names = [x.name for x in output_layers]
 
-    def _setup_nnx_op_mapping(self):
-        """Setup operation mapping for NNX"""
-        # Create a mapping from operation id to operation instance
-        self._nnx_op_mapping = {}
-
-        # Store operations as direct attributes for NNX traversal
-        for i, operation in enumerate(self._operations):
-            if isinstance(operation, Layer):
-                # Store operation as direct attribute with unique name
-                attr_name = f"_layer_{i}_{operation.name}"
-                setattr(self, attr_name, operation)
-                # Map the operation id to this operation instance
-                self._nnx_op_mapping[id(operation)] = operation
-
     def _lock_state(self):
         # Unlike other layers, we allow Functional state to be mutable after
         # build. E.g. to attach a layer to a model that is not part of the
@@ -186,6 +171,7 @@ def layers(self, _):
         )
 
     def call(self, inputs, training=None, mask=None, **kwargs):
+        # Add support for training, masking
         inputs = self._standardize_inputs(inputs)
         if mask is None:
             masks = [None] * len(inputs)
@@ -407,7 +393,7 @@ def get_config(self):
             # the author of the subclassed network).
             return Model.get_config(self)
 
-        cfg = {
+        config = {
             "name": self.name,
             "trainable": self.trainable,
         }
@@ -454,7 +440,7 @@ def get_config(self):
             layer_config["name"] = operation.name
             layer_config["inbound_nodes"] = filtered_inbound_nodes
             layer_configs.append(layer_config)
-        cfg["layers"] = layer_configs
+        config["layers"] = layer_configs
 
         # Gather info about inputs and outputs.
         def get_tensor_config(tensor):
@@ -469,9 +455,9 @@ def get_tensor_config(tensor):
         def map_tensors(tensors):
             return tree.map_structure(get_tensor_config, tensors)
 
-        cfg["input_layers"] = map_tensors(self._inputs_struct)
-        cfg["output_layers"] = map_tensors(self._outputs_struct)
-        return copy.deepcopy(cfg)
+        config["input_layers"] = map_tensors(self._inputs_struct)
+        config["output_layers"] = map_tensors(self._outputs_struct)
+        return copy.deepcopy(config)
 
 
 def functional_from_config(cls, config, custom_objects=None):

From 05d01196a44c39869b25d6d504db6769f5db1549 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Fri, 18 Jul 2025 21:31:14 +0000
Subject: [PATCH 094/103] simplify init

---
 keras/src/backend/jax/__init__.py | 6 +-----
 keras/src/backend/jax/core.py     | 8 ++++----
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py
index 335eed660b46..89ac0fa71c8c 100644
--- a/keras/src/backend/jax/__init__.py
+++ b/keras/src/backend/jax/__init__.py
@@ -11,11 +11,7 @@
 from keras.src.backend.jax.core import IS_THREAD_SAFE
 from keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS
 from keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS
-
-if is_nnx_enabled():
-    from keras.src.backend.jax.core import NnxVariable as Variable
-else:
-    from keras.src.backend.jax.core import JaxVariable as Variable
+from keras.src.backend.jax.core import Variable
 from keras.src.backend.jax.core import cast
 from keras.src.backend.jax.core import compute_output_spec
 from keras.src.backend.jax.core import cond
diff --git a/keras/src/backend/jax/core.py b/keras/src/backend/jax/core.py
index 9e9595d373c8..d43bda8f0bfe 100644
--- a/keras/src/backend/jax/core.py
+++ b/keras/src/backend/jax/core.py
@@ -58,7 +58,7 @@ def __jax_array__(self):
         return self.value
 
 
-_JAX_VARIABLE_TYPE = JaxVariable
+Variable = JaxVariable
 
 if config.is_nnx_enabled():
     from flax import nnx
@@ -231,7 +231,7 @@ def value(self):
                 )
             return self._maybe_autocast(current_value)
 
-    _JAX_VARIABLE_TYPE = NnxVariable
+    Variable = NnxVariable
 
 
 def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
@@ -247,7 +247,7 @@ def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
         # an existing distributed jax array will raise error.
         return x
 
-    if isinstance(x, _JAX_VARIABLE_TYPE):
+    if isinstance(x, Variable):
         if dtype is not None and x.dtype != dtype:
             return x.value.astype(dtype)
         return x.value
@@ -531,7 +531,7 @@ def fori_loop(lower, upper, body_fun, init_val):
 
 
 def stop_gradient(variable):
-    if isinstance(variable, _JAX_VARIABLE_TYPE):
+    if isinstance(variable, Variable):
         variable = variable.value
     return jax.lax.stop_gradient(variable)

From 7cceaae4e3e43d45d53e5d64e658d3481e68d7bb Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Sun, 20 Jul 2025 12:01:15 -0700
Subject: [PATCH 095/103] FIX MODEL BUILD ERROR

---
 keras/src/backend/jax/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index f835da474b7a..c5135d85af3c 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -647,8 +647,11 @@ def predict(
                 x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(
                     next(iterator)
                 )
-                with backend.StatelessScope():
+                if is_nnx_enabled():
                     self(x)
+                else:
+                    with backend.StatelessScope():
+                        self(x)
                 break
         epoch_iterator.reset()
         # Container that configures and calls callbacks.
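The predict-path fix in PATCH 095 encodes a general rule for the NNX mode: an unbuilt model should be built by a concrete forward pass, not under a StatelessScope, so that its variables hold real values NNX can trace. A hedged usage sketch (the model and shapes are arbitrary, and this assumes the JAX backend with NNX enabled):

import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
x = np.ones((2, 4), dtype="float32")
model(x)  # concrete call: creates and initializes the variables
preds = model.predict(x)  # predict now traces an already-built model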
From 3af3c30fb945fc114126a01733496e4d0c255fa7 Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Sun, 20 Jul 2025 12:09:09 -0700
Subject: [PATCH 096/103] revert changes to basic_full_flow.py

---
 integration_tests/basic_full_flow.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py
index 92ca3c3f1a07..6985533b1f01 100644
--- a/integration_tests/basic_full_flow.py
+++ b/integration_tests/basic_full_flow.py
@@ -52,8 +52,8 @@ def test_basic_fit(self):
     def test_basic_fit_no_training(self):
         model = MyModel(hidden_dim=2, output_dim=1)
         x = np.random.random((128, 4))
-        model(x)
         model.predict(x)
+        model(x)
 
     @pytest.mark.skipif(
         os.environ.get("KERAS_NNX_ENABLED") != "true",

From 7494af613013bdc5e855751113beb2855751f9f3 Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Sun, 20 Jul 2025 12:10:23 -0700
Subject: [PATCH 097/103] revert basic_full_flow.py

---
 integration_tests/basic_full_flow.py | 47 ----------------------------
 1 file changed, 47 deletions(-)

diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py
index 6985533b1f01..ae5c7a4c0449 100644
--- a/integration_tests/basic_full_flow.py
+++ b/integration_tests/basic_full_flow.py
@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 import pytest
 
@@ -54,48 +52,3 @@ def test_basic_fit_no_training(self):
         x = np.random.random((128, 4))
         model.predict(x)
         model(x)
-
-    @pytest.mark.skipif(
-        os.environ.get("KERAS_NNX_ENABLED") != "true",
-        reason="Test only runs with NNX enabled",
-    )
-    def test_bare_ops_functional(self):
-        """Test that functional models work correctly with bare ops."""
-        # Create input
-        inputs = keras.Input(shape=(10,))
-
-        # Add a layer
-        x = layers.Dense(5, activation="relu")(inputs)
-
-        # Add a bare op (not a layer)
-        x = keras.ops.add(x, 2.0)
-
-        # Add another layer
-        outputs = layers.Dense(1)(x)
-
-        # Create functional model
-        model = keras.Model(inputs=inputs, outputs=outputs)
-
-        # Test forward pass
-        test_input = np.random.random((2, 10))
-        output = model(test_input)
-
-        # Verify output shape and values
-        self.assertEqual(output.shape, (2, 1))
-        self.assertTrue(np.all(np.isfinite(output)))
-
-        # Test with multiple bare ops
-        inputs2 = keras.Input(shape=(5,))
-        x2 = layers.Dense(3, activation="relu")(inputs2)
-        x2 = keras.ops.add(x2, 1.0)
-        x2 = keras.ops.multiply(x2, 2.0)
-        x2 = keras.ops.subtract(x2, 0.5)
-        outputs2 = layers.Dense(1)(x2)
-
-        model2 = keras.Model(inputs=inputs2, outputs=outputs2)
-        test_input2 = np.random.random((3, 5))
-        output2 = model2(test_input2)
-
-        # Verify output shape and values
-        self.assertEqual(output2.shape, (3, 1))
-        self.assertTrue(np.all(np.isfinite(output2)))
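The test removed in PATCH 097 covered functional graphs that interleave layers with bare ops such as keras.ops.add — exactly the path the Function-level op mapping from PATCH 092 has to keep working under NNX. A compact sketch of that scenario, distilled from the deleted test:

import numpy as np
import keras
from keras import layers

inputs = keras.Input(shape=(10,))
x = layers.Dense(5, activation="relu")(inputs)
x = keras.ops.add(x, 2.0)  # bare op, not a Layer instance
outputs = layers.Dense(1)(x)
model = keras.Model(inputs=inputs, outputs=outputs)
print(model(np.random.random((2, 10))).shape)  # (2, 1)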
From 403681b2da7ec99d5926883026a318beb90c9ed7 Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Wed, 23 Jul 2025 19:08:48 -0700
Subject: [PATCH 098/103] address review comments

---
 integration_tests/basic_full_flow.py  | 25 +++++++++++++++++++++++++
 keras/src/backend/common/variables.py |  1 -
 keras/src/backend/jax/layer.py        |  7 ++++---
 keras/src/layers/layer.py             | 10 +++-------
 4 files changed, 32 insertions(+), 11 deletions(-)

diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py
index ae5c7a4c0449..a74fb81ea6e1 100644
--- a/integration_tests/basic_full_flow.py
+++ b/integration_tests/basic_full_flow.py
@@ -1,3 +1,5 @@
+import os
+
 import numpy as np
 import pytest
 
@@ -7,6 +9,7 @@
 from keras.src import metrics
 from keras.src import optimizers
 from keras.src import testing
+from keras.src.backend.common.variables import initialize_all_variables
 
 
 class MyModel(keras.Model):
@@ -52,3 +55,25 @@ def test_basic_fit_no_training(self):
         x = np.random.random((128, 4))
         model.predict(x)
         model(x)
+
+
+def test_nnx_variable_initializer_bug():
+    # Enable JAX + NNX backend
+    os.environ["KERAS_BACKEND"] = "jax"
+    os.environ["KERAS_NNX_ENABLED"] = "true"
+    import keras
+
+    model = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))])
+    x = np.ones((1, 2))
+    # First call: triggers tracing and variable initialization
+    model(x)
+    # Save the kernel value after first call
+    kernel_before = model.layers[0].kernel.value.copy()
+    # Now forcibly re-initialize all variables
+    initialize_all_variables()
+    # Check if the kernel value has changed
+    kernel_after = model.layers[0].kernel.value
+    assert np.allclose(kernel_before, kernel_after), (
+        "Kernel was re-initialized! This is a bug if NNX is enabled and "
+        "the initializer was not cleared."
+    )
diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py
index 21f52be89dca..d2df6d34d52f 100644
--- a/keras/src/backend/common/variables.py
+++ b/keras/src/backend/common/variables.py
@@ -215,7 +215,6 @@ def _deferred_initialize(self):
             # initialized by a concrete call. In this case,
             # _deferred_initialize becomes a no-op for this variable.
             if config.is_nnx_enabled():
-                self._initializer = None  # Clear initializer as it's now "used"
                 return
             raise ValueError(f"Variable {self.path} is already initialized.")
 
diff --git a/keras/src/backend/jax/layer.py b/keras/src/backend/jax/layer.py
index 7784bae431ed..bb53f5f42011 100644
--- a/keras/src/backend/jax/layer.py
+++ b/keras/src/backend/jax/layer.py
@@ -3,9 +3,10 @@
 if is_nnx_enabled():
     from flax import nnx
 
-    class NnxLayer(nnx.Module):
-        pass
+    BaseLayer = nnx.Module
+else:
+    BaseLayer = object
 
 
-class JaxLayer:
+class JaxLayer(BaseLayer):
     pass
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 70a15b6236ba..063baa8f5bdc 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -53,10 +53,7 @@
 if backend.backend() == "tensorflow":
     from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer
 elif backend.backend() == "jax":
-    if is_nnx_enabled():
-        from keras.src.backend.jax.layer import NnxLayer as BackendLayer
-    else:
-        from keras.src.backend.jax.layer import JaxLayer as BackendLayer
+    from keras.src.backend.jax.layer import JaxLayer as BackendLayer
 elif backend.backend() == "torch":
     from keras.src.backend.torch.layer import TorchLayer as BackendLayer
 elif backend.backend() == "numpy":
@@ -1547,6 +1544,7 @@ def __setattr__(self, name, value):
             value = self._tracker.track(value)
 
         # NNX-specific bypass for `_called` and `built` attributes
+        # bypass nnx.Module.__setattr__ which cannot be called while tracing
         if (
             backend.backend() == "jax"
             and is_nnx_enabled()
@@ -1555,9 +1553,7 @@ def __setattr__(self, name, value):
             object.__setattr__(self, name, value)
             return
 
-        super().__setattr__(
-            name, value
-        )  # Default path, including for NnxLayer -> nnx.Module
+        super().__setattr__(name, value)
 
     def __delattr__(self, name):
         obj = getattr(self, name)

From 7a2ddd8ce5f144c0def80b2eb406eb9363a326ca Mon Sep 17 00:00:00 2001
From: Divyashree Sreepathihalli
Date: Wed, 23 Jul 2025 19:09:06 -0700
Subject: [PATCH 099/103] add layer.py# modified: keras/src/layers/layer.py

---
 keras/src/layers/layer.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index 063baa8f5bdc..6dd945557603 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1674,15 +1674,8 @@ def _open_name_scope(self):
         # level. We check if we are in NNX mode and if we are in a JAX
         # trace.
         if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()):
-            try:
-                self._parent_path = current_path()
-            except Exception:
-                warnings.warn(
-                    f"Layer '{self.name}' encountered an issue during "
-                    "model construction. If you're experiencing unexpected "
-                    "behavior, try calling your model on a small test "
-                    "input first to ensure proper initialization."
-                )
+            self._parent_path = current_path()
+
         return backend.name_scope(self.name, caller=self)
 
     def rematerialized_call(self, layer_call, *args, **kwargs):

From 534b9751d04e0cd8591c0d6f2a6dd45a1fbdeece Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Thu, 24 Jul 2025 02:18:58 +0000
Subject: [PATCH 100/103] revert variables change

---
 integration_tests/basic_full_flow.py  | 25 -------------------------
 keras/src/backend/common/variables.py |  2 --
 2 files changed, 27 deletions(-)

diff --git a/integration_tests/basic_full_flow.py b/integration_tests/basic_full_flow.py
index a74fb81ea6e1..ae5c7a4c0449 100644
--- a/integration_tests/basic_full_flow.py
+++ b/integration_tests/basic_full_flow.py
@@ -1,5 +1,3 @@
-import os
-
 import numpy as np
 import pytest
 
@@ -9,7 +7,6 @@
 from keras.src import metrics
 from keras.src import optimizers
 from keras.src import testing
-from keras.src.backend.common.variables import initialize_all_variables
 
 
 class MyModel(keras.Model):
@@ -55,25 +52,3 @@ def test_basic_fit_no_training(self):
         x = np.random.random((128, 4))
         model.predict(x)
         model(x)
-
-
-def test_nnx_variable_initializer_bug():
-    # Enable JAX + NNX backend
-    os.environ["KERAS_BACKEND"] = "jax"
-    os.environ["KERAS_NNX_ENABLED"] = "true"
-    import keras
-
-    model = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))])
-    x = np.ones((1, 2))
-    # First call: triggers tracing and variable initialization
-    model(x)
-    # Save the kernel value after first call
-    kernel_before = model.layers[0].kernel.value.copy()
-    # Now forcibly re-initialize all variables
-    initialize_all_variables()
-    # Check if the kernel value has changed
-    kernel_after = model.layers[0].kernel.value
-    assert np.allclose(kernel_before, kernel_after), (
-        "Kernel was re-initialized! This is a bug if NNX is enabled and "
-        "the initializer was not cleared."
-    )
diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py
index d2df6d34d52f..88c0308278b0 100644
--- a/keras/src/backend/common/variables.py
+++ b/keras/src/backend/common/variables.py
@@ -214,8 +214,6 @@ def _deferred_initialize(self):
             # If NNX is enabled, it's possible the variable was already
            # initialized by a concrete call. In this case,
             # _deferred_initialize becomes a no-op for this variable.
-            if config.is_nnx_enabled():
-                return
             raise ValueError(f"Variable {self.path} is already initialized.")
 
         if in_stateless_scope():
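PATCH 098 above collapses the NnxLayer/JaxLayer pair into a single class with a conditional base, so call sites such as keras/src/layers/layer.py no longer branch on the NNX flag. A self-contained sketch of the idiom; is_nnx_enabled is stubbed here so the snippet runs outside Keras:

def is_nnx_enabled():
    return False  # stand-in for keras.src.backend.config.is_nnx_enabled

if is_nnx_enabled():
    from flax import nnx

    BaseLayer = nnx.Module
else:
    BaseLayer = object


class JaxLayer(BaseLayer):
    pass

Every JAX-backend Keras layer then inherits from nnx.Module exactly when NNX is enabled, without duplicating the class body or the import branching.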
From 96f65af7bc71091f97b70f4a183d5abdc532d0c0 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Thu, 24 Jul 2025 20:17:39 +0000
Subject: [PATCH 101/103] nit

---
 keras/src/backend/jax/core_test.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras/src/backend/jax/core_test.py b/keras/src/backend/jax/core_test.py
index 0578c97f4964..792cf25e67f0 100644
--- a/keras/src/backend/jax/core_test.py
+++ b/keras/src/backend/jax/core_test.py
@@ -24,7 +24,7 @@
     not is_nnx_enabled(),
     reason="Test requires NNX backend to be enabled by default for setup.",
 )
-class JaxCoreVariableTest(testing.TestCase):
+class NnxVariableTest(testing.TestCase):
     def setup(self):
         super().setup()
 

From 8f2798ea6ab4b98f660585b6f85c2700a4ab8ccc Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Thu, 24 Jul 2025 20:18:39 +0000
Subject: [PATCH 102/103] update init

---
 keras/src/backend/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py
index a200b17c914e..15f1af2145d5 100644
--- a/keras/src/backend/__init__.py
+++ b/keras/src/backend/__init__.py
@@ -39,7 +39,7 @@
     from keras.src.backend.tensorflow.core import Variable as BackendVariable
 elif backend() == "jax":
     from keras.src.backend.jax import *  # noqa: F403
-    from keras.src.backend.jax import Variable as BackendVariable
+    from keras.src.backend.jax.core import Variable as BackendVariable
 elif backend() == "torch":
     from keras.src.backend.torch import *  # noqa: F403
     from keras.src.backend.torch.core import Variable as BackendVariable

From fd629cf084649e5a7d8dc55f7216aa39a8c33607 Mon Sep 17 00:00:00 2001
From: divyashreepathihalli
Date: Thu, 24 Jul 2025 21:51:44 +0000
Subject: [PATCH 103/103] assign all operations to one attribute

---
 keras/src/ops/function.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/keras/src/ops/function.py b/keras/src/ops/function.py
index d7663e6415e8..4d04182f2470 100644
--- a/keras/src/ops/function.py
+++ b/keras/src/ops/function.py
@@ -112,11 +112,9 @@ def _setup_nnx_op_mapping(self):
         # Create a mapping from operation id to operation instance
         self._nnx_op_mapping = {}
 
-        # Store operations as direct attributes for NNX traversal
-        for i, operation in enumerate(self._operations):
-            # Store operation as direct attribute with unique name
-            attr_name = f"_op_{i}_{operation.name}"
-            setattr(self, attr_name, operation)
+        # Assign the list of operations to a single attribute for NNX traversal
+        self.nnx_operations = self._operations[:]
+        for operation in self._operations:
             # Map the operation id to this operation instance
             self._nnx_op_mapping[id(operation)] = operation
 
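PATCH 103 replaces the per-operation _op_{i}_{name} attributes with one list attribute. This relies on nnx traversing container attributes; a minimal sketch of the effect (illustrative module, and it assumes a flax.nnx version that registers plain list attributes as graph nodes):

from flax import nnx

class Ops(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        # A single attribute holds every operation; nnx walks the list,
        # so each element's parameters land in the extracted state.
        self.nnx_operations = [
            nnx.Linear(4, 8, rngs=rngs),
            nnx.Linear(8, 1, rngs=rngs),
        ]

graphdef, state = nnx.split(Ops(nnx.Rngs(0)))
print(state)  # parameters for both entries of nnx_operations

Keeping the separate _nnx_op_mapping dict preserves O(1) id-based lookup in _get_operation_for_node while avoiding dozens of synthetic attributes on the Function object.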