From 779ffcf5c5983615c26fee59c167991034dc3dcf Mon Sep 17 00:00:00 2001
From: Andrew Sweet
Date: Tue, 27 May 2025 00:20:12 +0000
Subject: [PATCH 1/2] example mlx workflow and adjusted tests for linux build

---
 examples/demo_custom_mlx_workflow.py          | 122 ++++++++++++++++++
 .../preprocessing/stft_spectrogram_test.py    |  22 +++-
 keras/src/trainers/trainer_test.py            |   5 +-
 3 files changed, 143 insertions(+), 6 deletions(-)
 create mode 100644 examples/demo_custom_mlx_workflow.py

diff --git a/examples/demo_custom_mlx_workflow.py b/examples/demo_custom_mlx_workflow.py
new file mode 100644
index 000000000000..34c7c5b3f190
--- /dev/null
+++ b/examples/demo_custom_mlx_workflow.py
@@ -0,0 +1,122 @@
+import os
+
+# Set backend env to MLX
+os.environ["KERAS_BACKEND"] = "mlx"
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from keras import Model
+from keras import initializers
+from keras import layers
+from keras import ops
+from keras import optimizers
+from keras import Variable
+
+
+class MyDense(layers.Layer):
+    def __init__(self, units, name=None):
+        super().__init__(name=name)
+        self.units = units
+
+    def build(self, input_shape):
+        input_dim = input_shape[-1]
+        w_shape = (input_dim, self.units)
+        w_value = initializers.GlorotUniform()(w_shape)
+        self.w = Variable(w_value, name="kernel")
+
+        b_shape = (self.units,)
+        b_value = initializers.Zeros()(b_shape)
+        self.b = Variable(b_value, name="bias")
+
+    def call(self, inputs):
+        return ops.matmul(inputs, self.w) + self.b
+
+
+class MyModel(Model):
+    def __init__(self, hidden_dim, output_dim):
+        super().__init__()
+        self.dense1 = MyDense(hidden_dim)
+        self.dense2 = MyDense(hidden_dim)
+        self.dense3 = MyDense(output_dim)
+
+    def call(self, x):
+        x = nn.relu(self.dense1(x))
+        x = nn.relu(self.dense2(x))
+        return self.dense3(x)
+    
+
+def Dataset():
+    for _ in range(20):
+        yield (mx.random.normal((32, 128)), mx.random.normal((32, 4)))
+
+
+def loss_fn(y_true, y_pred):
+    return ops.sum((y_true - y_pred) ** 2)
+
+
+model = MyModel(hidden_dim=256, output_dim=4)
+
+optimizer = optimizers.SGD(learning_rate=0.001)
+dataset = Dataset()
+
+# Build model
+x = mx.random.normal((1, 128))
+model(x)
+# Build optimizer
+optimizer.build(model.trainable_variables)
+
+
+######### Custom MLX workflow ###############
+
+
+def compute_loss_and_updates(
+    trainable_variables, non_trainable_variables, x, y
+):
+    y_pred, non_trainable_variables = model.stateless_call(
+        trainable_variables, non_trainable_variables, x
+    )
+    loss = loss_fn(y, y_pred)
+    return loss, non_trainable_variables
+
+
+grad_fn = mx.value_and_grad(compute_loss_and_updates)
+
+
+
+@mx.compile
+def train_step(state, data):
+    trainable_variables, non_trainable_variables, optimizer_variables = state
+    x, y = data
+    (loss, non_trainable_variables), grads = grad_fn(
+        trainable_variables, non_trainable_variables, x, y
+    )
+    trainable_variables, optimizer_variables = optimizer.stateless_apply(
+        optimizer_variables, grads, trainable_variables
+    )
+    # Return updated state
+    return loss, (
+        trainable_variables,
+        non_trainable_variables,
+        optimizer_variables,
+    )
+
+
+# Pass lists of arrays as state for compiled train_step
+trainable_variables = [tv.value for tv in model.trainable_variables]
+non_trainable_variables = [ntv.value for ntv in model.non_trainable_variables]
+optimizer_variables = [ov.value for ov in optimizer.variables]
+state = trainable_variables, non_trainable_variables, optimizer_variables
+# Training loop
+for data in dataset:
+    loss, state = train_step(state, data)
+    print("Loss:", loss)
+
+# Post-processing model state update
+trainable_variables, non_trainable_variables, optimizer_variables = state
+for variable, value in zip(model.trainable_variables, trainable_variables):
+    variable.assign(value)
+for variable, value in zip(
+    model.non_trainable_variables, non_trainable_variables
+):
+    variable.assign(value)
\ No newline at end of file
diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py
index d0fa498a5f99..680769b86c60 100644
--- a/keras/src/layers/preprocessing/stft_spectrogram_test.py
+++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py
@@ -96,8 +96,14 @@ def test_spectrogram_channels_broadcasting(self):
             for i in range(audio.shape[-1])
         ]
 
-        self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1))
-        self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1))
+        if backend.backend() == "mlx":
+            atol = 1e-5
+            rtol = 1e-5
+        else:
+            atol = 1e-6
+            rtol = 1e-6
+        self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1), atol=atol, rtol=rtol)
+        self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1), atol=atol, rtol=rtol)
 
     @pytest.mark.skipif(
         backend.backend() == "tensorflow",
@@ -153,11 +159,19 @@ def test_spectrogram_channels_first(self):
         )
         y_last = layer_last.predict(audio, verbose=0)
         y_first = layer_first.predict(np.transpose(audio, [0, 2, 1]), verbose=0)
-        self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last)
-        self.assertAllClose(y_expanded, np.stack(y_singles, axis=1))
+        if backend.backend() == "mlx":
+            atol = 1e-5
+            rtol = 1e-5
+        else:
+            atol = 1e-6
+            rtol = 1e-6
+        self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last, atol=atol, rtol=rtol)
+        self.assertAllClose(y_expanded, np.stack(y_singles, axis=1), atol=atol, rtol=rtol)
         self.assertAllClose(
             y_first,
             np.transpose(np.concatenate(y_singles, axis=-1), [0, 2, 1]),
+            atol=atol,
+            rtol=rtol
         )
         self.run_layer_test(
             layers.STFTSpectrogram,
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index 0880d81dbc2c..b3b7a7201768 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -627,9 +627,10 @@ def test_fit_with_data_adapter(
     ):
         if (
            dataset_kwargs.get("use_multiprocessing", False)
-            and backend.backend() == "jax"
+            and backend.backend() in ["jax", "mlx"]
         ):
-            pytest.skip("Multiprocessing not supported with JAX backend")
+            # note: mlx multiprocessing works on Apple silicon but not Linux
+            pytest.skip("Multiprocessing not supported with JAX and MLX backends")
 
         model = ExampleModel(units=3)
         optimizer = optimizers.Adagrad()

From 8856a5ad654c0934e72e096246c992e75a892ef3 Mon Sep 17 00:00:00 2001
From: Andrew Sweet
Date: Mon, 23 Jun 2025 23:53:32 -0700
Subject: [PATCH 2/2] changes for tests

---
 examples/demo_custom_mlx_workflow.py          |  5 ++---
 .../preprocessing/stft_spectrogram_test.py    | 18 +++++++++++++-----
 keras/src/ops/nn_test.py                      |  4 ++++
 keras/src/trainers/trainer_test.py            | 11 ++++++-----
 4 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/examples/demo_custom_mlx_workflow.py b/examples/demo_custom_mlx_workflow.py
index 34c7c5b3f190..c9c6957d66df 100644
--- a/examples/demo_custom_mlx_workflow.py
+++ b/examples/demo_custom_mlx_workflow.py
@@ -44,7 +44,7 @@ def call(self, x):
         x = nn.relu(self.dense1(x))
         x = nn.relu(self.dense2(x))
         return self.dense3(x)
-    
+
 
 def Dataset():
     for _ in range(20):
@@ -83,7 +83,6 @@ def compute_loss_and_updates(
 grad_fn = mx.value_and_grad(compute_loss_and_updates)
 
 
-
 @mx.compile
 def train_step(state, data):
     trainable_variables, non_trainable_variables, optimizer_variables = state
@@ -119,4 +118,4 @@ def train_step(state, data):
 for variable, value in zip(
     model.non_trainable_variables, non_trainable_variables
 ):
-    variable.assign(value)
\ No newline at end of file
+    variable.assign(value)
diff --git a/keras/src/layers/preprocessing/stft_spectrogram_test.py b/keras/src/layers/preprocessing/stft_spectrogram_test.py
index 680769b86c60..cba441d5fe31 100644
--- a/keras/src/layers/preprocessing/stft_spectrogram_test.py
+++ b/keras/src/layers/preprocessing/stft_spectrogram_test.py
@@ -102,8 +102,12 @@ def test_spectrogram_channels_broadcasting(self):
         else:
             atol = 1e-6
             rtol = 1e-6
-        self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1), atol=atol, rtol=rtol)
-        self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1), atol=atol, rtol=rtol)
+        self.assertAllClose(
+            y_last, np.concatenate(y_singles, axis=-1), atol=atol, rtol=rtol
+        )
+        self.assertAllClose(
+            y_expanded, np.stack(y_singles, axis=-1), atol=atol, rtol=rtol
+        )
 
     @pytest.mark.skipif(
         backend.backend() == "tensorflow",
@@ -165,13 +169,17 @@ def test_spectrogram_channels_first(self):
         else:
             atol = 1e-6
             rtol = 1e-6
-        self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last, atol=atol, rtol=rtol)
-        self.assertAllClose(y_expanded, np.stack(y_singles, axis=1), atol=atol, rtol=rtol)
+        self.assertAllClose(
+            np.transpose(y_first, [0, 2, 1]), y_last, atol=atol, rtol=rtol
+        )
+        self.assertAllClose(
+            y_expanded, np.stack(y_singles, axis=1), atol=atol, rtol=rtol
+        )
         self.assertAllClose(
             y_first,
             np.transpose(np.concatenate(y_singles, axis=-1), [0, 2, 1]),
             atol=atol,
-            rtol=rtol
+            rtol=rtol,
         )
         self.run_layer_test(
             layers.STFTSpectrogram,
diff --git a/keras/src/ops/nn_test.py b/keras/src/ops/nn_test.py
index 421a42342530..770f96b80bba 100644
--- a/keras/src/ops/nn_test.py
+++ b/keras/src/ops/nn_test.py
@@ -2504,6 +2504,10 @@ class NNOpsDtypeTest(testing.TestCase):
 
     FLOAT_DTYPES = dtypes.FLOAT_TYPES
 
+    if backend.backend() == "mlx":
+        # activations in mlx have an issue with float64
+        FLOAT_DTYPES = tuple([ft for ft in FLOAT_DTYPES if ft != "float64"])
+
     def setUp(self):
         from jax.experimental import enable_x64
 
diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py
index b3b7a7201768..c0187512a5d4 100644
--- a/keras/src/trainers/trainer_test.py
+++ b/keras/src/trainers/trainer_test.py
@@ -625,12 +625,13 @@ def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):
     def test_fit_with_data_adapter(
         self, dataset_type, dataset_kwargs={}, fit_kwargs={}
     ):
-        if (
-            dataset_kwargs.get("use_multiprocessing", False)
-            and backend.backend() in ["jax", "mlx"]
-        ):
+        if dataset_kwargs.get(
+            "use_multiprocessing", False
+        ) and backend.backend() in ["jax", "mlx"]:
             # note: mlx multiprocessing works on Apple silicon but not Linux
-            pytest.skip("Multiprocessing not supported with JAX and MLX backends")
+            pytest.skip(
+                "Multiprocessing not supported with JAX and MLX backends"
+            )
 
         model = ExampleModel(units=3)
         optimizer = optimizers.Adagrad()
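
Note on the demo above: train_step threads plain mx.array state through
mx.compile, in the style of the repo's existing custom JAX workflow demo,
and the script runs directly via `python examples/demo_custom_mlx_workflow.py`
on a machine with the mlx package and an mlx-enabled Keras build installed.
For readers unfamiliar with the MLX side, here is a minimal, self-contained
sketch of the two primitives the demo relies on, mx.value_and_grad and
mx.compile. The linear-model names (mse, step, w, lr) are illustrative only
and not part of this patch:

import mlx.core as mx


def mse(w, x, y):
    # Mean squared error of a linear model: y_pred = x @ w.
    return ((x @ w - y) ** 2).mean()


# Differentiates mse with respect to its first positional argument (w)
# and returns (value, gradient), analogous to jax.value_and_grad.
grad_fn = mx.value_and_grad(mse)


@mx.compile
def step(w, x, y, lr):
    loss, grad = grad_fn(w, x, y)
    # Compiled functions should stay pure: return the updated weights
    # instead of mutating state in place, just like train_step returns
    # its updated state tuple in the demo.
    return loss, w - lr * grad


w = mx.zeros((4, 1))
x = mx.random.normal((8, 4))
y = mx.random.normal((8, 1))
lr = mx.array(0.01)
for _ in range(5):
    loss, w = step(w, x, y, lr)
    print("loss:", loss.item())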