From a7a24371adffb4f76bf98139be43e12bfa1f4524 Mon Sep 17 00:00:00 2001 From: Zixiang Zhou Date: Thu, 18 Sep 2025 13:23:04 -0700 Subject: [PATCH] Reverts changelist 793734230 PiperOrigin-RevId: 808711200 --- recml/core/data/tf_dataset_factory.py | 32 +- recml/core/ops/hstu_ops.py | 16 +- recml/core/training/keras_trainer.py | 122 ++-- recml/core/training/keras_trainer_test.py | 17 +- recml/core/training/partitioning.py | 10 +- recml/core/utils/config_test.py | 4 +- recml/core/utils/keras_utils.py | 315 ++++++++++- recml/core/utils/keras_utils_test.py | 248 ++++++--- recml/examples/DLRM_HSTU/action_encoder.py | 123 +++++ .../examples/DLRM_HSTU/action_encoder_test.py | 147 +++++ recml/examples/DLRM_HSTU/content_encoder.py | 170 ++++++ .../DLRM_HSTU/content_encoder_test.py | 106 ++++ .../contextual_interleave_preprocessor.py | 164 ++++++ .../examples/DLRM_HSTU/contextualize_mlps.py | 179 ++++++ recml/examples/DLRM_HSTU/dlrm_hstu.py | 519 ++++++++++++++++++ recml/examples/DLRM_HSTU/dlrm_hstu_test.py | 366 ++++++++++++ recml/examples/DLRM_HSTU/hstu_transducer.py | 242 ++++++++ .../DLRM_HSTU/movielens_dataloader.py | 181 ++++++ .../DLRM_HSTU/movielens_dlrm_hstu_test.py | 310 +++++++++++ recml/examples/DLRM_HSTU/multitask_module.py | 273 +++++++++ .../examples/DLRM_HSTU/positional_encoder.py | 242 ++++++++ recml/examples/DLRM_HSTU/postprocessors.py | 171 ++++++ recml/examples/DLRM_HSTU/preprocessors.py | 131 +++++ recml/examples/DLRM_HSTU/stu.py | 357 ++++++++++++ recml/examples/DLRM_HSTU/stu_test.py | 349 ++++++++++++ recml/layers/linen/sparsecore.py | 7 +- 26 files changed, 4633 insertions(+), 168 deletions(-) create mode 100644 recml/examples/DLRM_HSTU/action_encoder.py create mode 100644 recml/examples/DLRM_HSTU/action_encoder_test.py create mode 100644 recml/examples/DLRM_HSTU/content_encoder.py create mode 100644 recml/examples/DLRM_HSTU/content_encoder_test.py create mode 100644 recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py create mode 100644 recml/examples/DLRM_HSTU/contextualize_mlps.py create mode 100644 recml/examples/DLRM_HSTU/dlrm_hstu.py create mode 100644 recml/examples/DLRM_HSTU/dlrm_hstu_test.py create mode 100644 recml/examples/DLRM_HSTU/hstu_transducer.py create mode 100644 recml/examples/DLRM_HSTU/movielens_dataloader.py create mode 100644 recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py create mode 100644 recml/examples/DLRM_HSTU/multitask_module.py create mode 100644 recml/examples/DLRM_HSTU/positional_encoder.py create mode 100644 recml/examples/DLRM_HSTU/postprocessors.py create mode 100644 recml/examples/DLRM_HSTU/preprocessors.py create mode 100644 recml/examples/DLRM_HSTU/stu.py create mode 100644 recml/examples/DLRM_HSTU/stu_test.py diff --git a/recml/core/data/tf_dataset_factory.py b/recml/core/data/tf_dataset_factory.py index 7c0ead8..14db3bb 100644 --- a/recml/core/data/tf_dataset_factory.py +++ b/recml/core/data/tf_dataset_factory.py @@ -24,6 +24,7 @@ import re from typing import Any, Protocol +from absl import flags from absl import logging import jax from recml.core.utils import types @@ -162,12 +163,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]): Defaults to False. seed: An optional seed to use for deterministic shuffling / preprocessing. Defaults to None. - tf_data_service_address: An optional URI of a tf.data service to offload - preprocessing onto during training. The URI should be in the format - "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no - data service will be applied. 
+ enable_tf_data_service: Whether to apply tf.data service for this dataset. + If True, flag `tf_data_service_address` must be set. tf_data_service_policy: Sharding policy to use for tf.data service when it is enabled. + tf_data_service_job_name: Job name to use for tf.data service. If None, the + default job name will be used. feature_spec: A mapping of feature keys to `FixedLenFeature`, `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be used to parse the TF examples, or as context_features spec to parse TF @@ -208,7 +209,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]): tensorflow. debug: An optional boolean indicating whether to debug input boundedness. If `True`, the dataset will consist of a single batch that's cached and - infinitely repeated + infinitely repeated. """ cache_reading: bool = False @@ -231,7 +232,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]): readahead: str | None = None group_uris_by_dir: bool = False seed: int | None = None - tf_data_service_address: str | None = None + enable_tf_data_service: bool = False + tf_data_service_job_name: str | None = None tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = ( tf.data.experimental.service.ShardingPolicy.OFF ) @@ -249,7 +251,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]): debug: bool = False def __post_init__(self): - if self.tf_data_service_address is not None: + if self.enable_tf_data_service: + if flags.FLAGS.tf_data_service_address is None: + raise ValueError( + "Flag `tf_data_service_address` must be set when" + " `enable_tf_data_service` is True." + ) if self.seed is not None: raise ValueError("`seed` must be None for data service.") if self.sharding: @@ -533,23 +540,26 @@ def _maybe_apply_tf_data_service( self, dataset: tf.data.Dataset ) -> tf.data.Dataset: """Applies the tf.data service to the dataset.""" - if self.tf_data_service_address is None: + if not self.enable_tf_data_service: return dataset + tf_data_service_address = flags.FLAGS.tf_data_service_address + per_proc_batch_size = self.sharding_info.per_process_batch_size( self.global_batch_size ) logging.info( "Applying tf.data service with address %s and per replica batch" " size %s", - self.tf_data_service_address, + tf_data_service_address, per_proc_batch_size, ) return dataset.apply( tf.data.experimental.service.distribute( processing_mode=self.tf_data_service_policy, - service=self.tf_data_service_address, - job_name=f"bs_{per_proc_batch_size}", + service=tf_data_service_address, + job_name=self.tf_data_service_job_name + or "tf_data_service_shared_job_name", ) ) diff --git a/recml/core/ops/hstu_ops.py b/recml/core/ops/hstu_ops.py index 3a8df11..59fd7bd 100644 --- a/recml/core/ops/hstu_ops.py +++ b/recml/core/ops/hstu_ops.py @@ -125,9 +125,9 @@ def _apply_mask( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] snm = jnp.where(should_not_mask, 1, 0) masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0) @@ -156,7 +156,7 @@ def _apply_mask( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -170,7 +170,7 @@ def _apply_mask( if 
q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -181,9 +181,9 @@ def _apply_mask( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) if masks: @@ -228,7 +228,7 @@ def body(kv_compute_index, _): slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) q = q_ref[...] - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] qk = jax.lax.dot_general( q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32 ) @@ -256,7 +256,7 @@ def body(kv_compute_index, _): ) sv_dims = NN_DIM_NUMBERS - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] to_float32 = lambda x: x.astype(jnp.float32) v = to_float32(v) diff --git a/recml/core/training/keras_trainer.py b/recml/core/training/keras_trainer.py index 6c24223..a122072 100644 --- a/recml/core/training/keras_trainer.py +++ b/recml/core/training/keras_trainer.py @@ -18,6 +18,7 @@ import abc from collections.abc import Mapping import dataclasses +import functools import gc import os import time @@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str): model: The Keras model constructed by `create_model`. model_dir: The model directory passed to the trainer. """ - model.save(os.path.join(model_dir, core.KERAS_MODEL_SAVEFILE)) class KerasTrainer(core.Trainer[KerasTask]): @@ -118,6 +118,7 @@ def __init__( max_checkpoints_to_keep: int = 5, checkpoint_save_interval_epochs: int = 1, rng_seed: int = core.DEFAULT_RNG_SEED, + legacy_checkpoint_format: bool = True, ): """Initializes the instance.""" @@ -143,60 +144,77 @@ def __init__( self._steps_per_eval = steps_per_eval self._continuous_eval_timeout = continuous_eval_timeout self._steps_per_loop = steps_per_loop - self._checkpoint_manager = None self._marker_path = os.path.join( model_dir, core.TRAINING_COMPLETE_MARKER_FILE ) self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR) + self._max_checkpoints_to_keep = max_checkpoints_to_keep + self._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs + self._legacy_checkpoint_format = legacy_checkpoint_format + @functools.cached_property + def train_callbacks(self) -> list[keras.callbacks.Callback]: + """Returns the training callbacks.""" if keras.backend.backend() == "jax": - self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager( - checkpoint_dir=self._checkpoint_dir, - max_to_keep=max_checkpoints_to_keep, - save_interval_epochs=checkpoint_save_interval_epochs, - ) - self._train_callbacks = [ + if self._legacy_checkpoint_format: + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager( + checkpoint_dir=self._checkpoint_dir, + max_to_keep=self._max_checkpoints_to_keep, + save_interval_epochs=self._checkpoint_save_interval_epochs, + ) + else: + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2( + checkpoint_dir=self._checkpoint_dir, + max_to_keep=self._max_checkpoints_to_keep, + save_interval_epochs=self._checkpoint_save_interval_epochs, + ) + return [ 
keras_utils.EpochSummaryCallback( - log_dir=os.path.join(model_dir, core.LOG_DIR), - steps_per_epoch=steps_per_loop, + log_dir=os.path.join(self._model_dir, core.LOG_DIR), + steps_per_epoch=self._steps_per_loop, write_steps_per_second=True, ), keras_utils.EpochOrbaxCheckpointAndRestoreCallback( - checkpoint_manager=self._checkpoint_manager, + checkpoint_manager=checkpoint_manager, marker_path=self._marker_path, ), ] - self._eval_callbacks = [ + return [ + keras.callbacks.TensorBoard( + log_dir=os.path.join(self._model_dir, core.LOG_DIR), + write_steps_per_second=True, + ), + keras.callbacks.BackupAndRestore( + backup_dir=os.path.join(self._model_dir, core.BACKUP_DIR), + ), + keras.callbacks.ModelCheckpoint( + filepath=os.path.join( + self._model_dir, + core.CHECKPOINT_DIR, + "ckpt-{epoch:d}.weights.h5", + ), + save_weights_only=True, + verbose=1, + ), + ] + + @functools.cached_property + def eval_callbacks(self) -> list[keras.callbacks.Callback]: + """Returns the evaluation callbacks.""" + if keras.backend.backend() == "jax": + return [ keras_utils.EpochSummaryCallback( - log_dir=os.path.join(model_dir, core.LOG_DIR), - steps_per_epoch=steps_per_loop, + log_dir=os.path.join(self._model_dir, core.LOG_DIR), + steps_per_epoch=self._steps_per_loop, write_steps_per_second=False, ), ] - else: - self._checkpoint_manager = None - self._train_callbacks = [ - keras.callbacks.TensorBoard( - log_dir=os.path.join(model_dir, core.LOG_DIR), - write_steps_per_second=True, - ), - keras.callbacks.BackupAndRestore( - backup_dir=os.path.join(model_dir, core.BACKUP_DIR), - ), - keras.callbacks.ModelCheckpoint( - filepath=os.path.join( - model_dir, core.CHECKPOINT_DIR, "ckpt-{epoch:d}.weights.h5" - ), - save_weights_only=True, - verbose=1, - ), - ] - self._eval_callbacks = [ - keras.callbacks.TensorBoard( - log_dir=os.path.join(model_dir, core.LOG_DIR), - write_steps_per_second=True, - ), - ] + return [ + keras.callbacks.TensorBoard( + log_dir=os.path.join(self._model_dir, core.LOG_DIR), + write_steps_per_second=True, + ), + ] def _maybe_get_model_kws( self, task: KerasTask, dataset: tf.data.Dataset @@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs: dataset, epochs=self._train_epochs, steps_per_epoch=self._steps_per_loop, - callbacks=self._train_callbacks, + callbacks=self.train_callbacks, ) model.summary(print_fn=logging.info) @@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs: if keras.backend.backend() == "jax": [tb_cbk] = [ cbk - for cbk in self._eval_callbacks + for cbk in self.eval_callbacks if isinstance(cbk, keras_utils.EpochSummaryCallback) ] epoch_start_time = time.time() history = model.evaluate( dataset, steps=self._steps_per_eval, - callbacks=self._eval_callbacks, + callbacks=self.eval_callbacks, return_dict=True, ) epoch_dt = time.time() - epoch_start_time @@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs: return model.evaluate( dataset, steps=self._steps_per_eval, - callbacks=self._eval_callbacks, + callbacks=self.eval_callbacks, ) def train_and_evaluate(self, task: KerasTask) -> core.Logs: @@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs: steps_per_epoch=self._steps_per_loop, # Explicitly set to None for deterministic evaluation. validation_steps=None, - callbacks=self._train_callbacks, + callbacks=self.train_callbacks, ) model.summary(print_fn=logging.info) @@ -308,7 +326,10 @@ def timeout_fn() -> bool: else: steps_msg = "running complete evaluation..." 
+ use_legacy_checkpoint_format = self._legacy_checkpoint_format + class _RestoreCallback(keras.callbacks.Callback): + """Callback for restoring the model from the latest checkpoint.""" def __init__( self, @@ -319,9 +340,14 @@ def __init__( self._epoch = epoch def on_test_begin(self, logs: Mapping[str, Any] | None = None): - keras_utils.restore_keras_model( - model, self._checkpoint_dir, step=self._epoch - ) + if use_legacy_checkpoint_format: + keras_utils.restore_keras_model( + model, self._checkpoint_dir, step=self._epoch + ) + else: + keras_utils.restore_keras_checkpoint( + self._checkpoint_dir, model=model, epoch=self._epoch + ) history = None for epoch in ocp.checkpoint_utils.checkpoints_iterator( @@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None): restore_callback = _RestoreCallback(self._checkpoint_dir, epoch) [tb_cbk] = [ cbk - for cbk in self._eval_callbacks + for cbk in self.eval_callbacks if isinstance(cbk, keras_utils.EpochSummaryCallback) ] try: @@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None): history = model.evaluate( eval_dataset, steps=self._steps_per_eval, - callbacks=[restore_callback] + self._eval_callbacks, + callbacks=[restore_callback] + self.eval_callbacks, return_dict=True, ) diff --git a/recml/core/training/keras_trainer_test.py b/recml/core/training/keras_trainer_test.py index 46844da..b800a89 100644 --- a/recml/core/training/keras_trainer_test.py +++ b/recml/core/training/keras_trainer_test.py @@ -59,11 +59,23 @@ def setUp(self): "mode": core.Experiment.Mode.TRAIN_AND_EVAL, }, { - "testcase_name": "continuous_eval", + "testcase_name": "continuous_eval_", "mode": core.Experiment.Mode.CONTINUOUS_EVAL, }, + { + "testcase_name": "train_and_eval_legacy_checkpoint_format", + "mode": core.Experiment.Mode.TRAIN_AND_EVAL, + "legacy_checkpoint_format": True, + }, + { + "testcase_name": "continuous_eval_legacy_checkpoint_format", + "mode": core.Experiment.Mode.CONTINUOUS_EVAL, + "legacy_checkpoint_format": True, + }, ) - def test_keras_task_and_trainer(self, mode: str): + def test_keras_task_and_trainer( + self, mode: str, legacy_checkpoint_format: bool = False + ): if keras.backend.backend() == "jax": distribution = keras.distribution.DataParallel() else: @@ -78,6 +90,7 @@ def test_keras_task_and_trainer(self, mode: str): steps_per_loop=2, model_dir=self.create_tempdir().full_path, continuous_eval_timeout=5, + legacy_checkpoint_format=legacy_checkpoint_format, ) experiment = core.Experiment(_KerasTask(), trainer) diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py index 4dc3b76..eabce4a 100644 --- a/recml/core/training/partitioning.py +++ b/recml/core/training/partitioning.py @@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array: def partition_init( self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None ) -> CreateStateFn: - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): if abstract_batch is not None: abstract_state = jax.eval_shape(init_fn, abstract_batch) specs = nn.get_partition_spec(abstract_state) @@ -117,7 +117,7 @@ def partition_init( init_fn = jax.jit(init_fn, out_shardings=self.state_sharding) def _wrapped_init(batch: PyTree) -> State: - with jax.sharding.use_mesh(self.mesh): + with jax.set_mesh(self.mesh): state = init_fn(batch) state = _maybe_unbox_state(state) return state @@ -130,7 +130,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn: jit_kws["out_shardings"] = (self.state_sharding, 
None)
      jit_kws["donate_argnums"] = (1,)
 
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -138,7 +138,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
 
     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)
 
     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh
 
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
diff --git a/recml/core/utils/config_test.py b/recml/core/utils/config_test.py
index 639270f..2ea6df5 100644
--- a/recml/core/utils/config_test.py
+++ b/recml/core/utils/config_test.py
@@ -72,8 +72,8 @@ class ConfigTest(parameterized.TestCase):
             'testcase_name': 'relative_fiddler',
             'args': [
                 f'config:{_TEST_MODULE_NAME}.config_1',
-                'fiddler:fiddler_1',
-                'fiddler:fiddler_2(value=4)',
+                f'fiddler:{_TEST_MODULE_NAME}.fiddler_1',
+                f'fiddler:{_TEST_MODULE_NAME}.fiddler_2(value=4)',
             ],
             'expected_config': fdl.Config(
                 _Object,
diff --git a/recml/core/utils/keras_utils.py b/recml/core/utils/keras_utils.py
index 1ee6b39..be7a755 100644
--- a/recml/core/utils/keras_utils.py
+++ b/recml/core/utils/keras_utils.py
@@ -14,6 +14,7 @@
 """Utilities for training Keras models on Jax backend."""
 
 from collections.abc import Mapping
+import functools
 from typing import Any
 
 from absl import logging
@@ -22,6 +23,12 @@
 import orbax.checkpoint as ocp
 import tensorflow as tf
 
+
+STATE_CHECKPOINT_KEY = "state"
+TRAINABLE_VARIABLES_KEY = "trainable_variables"
+NON_TRAINABLE_VARIABLES_KEY = "non_trainable_variables"
+OPTIMIZER_VARIABLES_KEY = "optimizer_variables"
+CONFIG_CHECKPOINT_KEY = "config"
 ORBAX_CHECKPOINT_DEFAULT_KEY = "default"
 
 
@@ -34,6 +41,303 @@ def _assert_variables_built(model: keras.Model):
     )
 
 
+def _assert_all_layers_built(model: keras.Model):
+  flattened_layers = model._flatten_layers(include_self=True)  # pylint: disable=protected-access
+  if not all(layer.built for layer in flattened_layers):
+    raise ValueError(
+        "To save or restore a checkpoint with a Keras model, the model and"
+        " all of its layers must be built. The layers that are not built"
+        " properly are the following:"
+        f" {[layer for layer in flattened_layers if not layer.built]}."
+    )
+
+
+def _to_shape_dtype_struct(x: keras.Variable) -> jax.ShapeDtypeStruct:
+  if not isinstance(x, keras.Variable):
+    raise ValueError(f"Expected a `keras.Variable`, got {type(x)}.")
+  return jax.ShapeDtypeStruct(
+      shape=x.value.shape,
+      dtype=x.value.dtype,
+      sharding=x.value.sharding,
+  )
+
+
+class KerasOrbaxCheckpointManagerV2(ocp.CheckpointManager):
+  """An Orbax checkpoint manager for Keras 3."""
+
+  def __init__(
+      self,
+      checkpoint_dir: str,
+      max_to_keep: int = 5,
+      save_interval_epochs: int = 1,
+  ):
+    """Initializes a KerasOrbaxCheckpointManagerV2.
+
+    Args:
+      checkpoint_dir: The directory to save checkpoints to.
+      max_to_keep: The maximum number of checkpoints to keep.
+      save_interval_epochs: The interval (in epochs) to save checkpoints.
+    """
+    if keras.backend.backend() != "jax":
+      raise ValueError(
+          "`KerasOrbaxCheckpointManagerV2` is only supported on a `jax`"
+          " backend."
+ ) + super().__init__( + directory=checkpoint_dir, + options=ocp.CheckpointManagerOptions( + save_interval_steps=save_interval_epochs, + max_to_keep=max_to_keep, + ), + ) + + def save_model_variables( + self, + model: keras.Model, + epoch: int, + logs: Mapping[str, Any] | None = None, + ): + """Saves the model variables and optimizer variables to a checkpoint.""" + _assert_variables_built(model) + _assert_all_layers_built(model) + + if not model._jax_state_synced: # pylint: disable=protected-access + model.jax_state_sync() + + variables = { + TRAINABLE_VARIABLES_KEY: model.trainable_variables, + NON_TRAINABLE_VARIABLES_KEY: model.non_trainable_variables, + OPTIMIZER_VARIABLES_KEY: model.optimizer.variables, + } + state = jax.tree.map(lambda x: x.value, variables) + config = keras.utils.serialize_keras_object(model) + + logging.info("Saving checkpoint for epoch %s...", epoch) + self.save( + step=epoch, + args=ocp.args.Composite(**{ + STATE_CHECKPOINT_KEY: ocp.args.StandardSave(state), + CONFIG_CHECKPOINT_KEY: ocp.args.JsonSave(config), + }), + metrics=logs, + ) + + def restore_model_variables(self, model: keras.Model, epoch: int): + """Restores the model variables and optimizer variables during training.""" + + _assert_variables_built(model) + _assert_all_layers_built(model) + + if not model._jax_state_synced: # pylint: disable=protected-access + model.jax_state_sync() + + variables = { + TRAINABLE_VARIABLES_KEY: model.trainable_variables, + NON_TRAINABLE_VARIABLES_KEY: model.non_trainable_variables, + OPTIMIZER_VARIABLES_KEY: model.optimizer.variables, + } + + # TODO(zixiangzhou): Update variables to use a nested dictionary and index + # map instead of flattened list. + + # Construct abstract variables to ensure the checkpoint is restored with + # the same sharding as the current variables. This is so we can delete the + # variables from device memory to reduce peak memory usage. + abstract_variables = jax.tree.map(_to_shape_dtype_struct, variables) + for var in jax.tree.flatten(variables)[0]: + var.value.delete() + var._value = None # pylint: disable=protected-access + + logging.info("Restoring checkpoint for epoch %s...", epoch) + + restored_items = self.restore( + step=epoch, + args=ocp.args.Composite(**{ + STATE_CHECKPOINT_KEY: ocp.args.StandardRestore(abstract_variables) + }), + ) + restored_variables = restored_items[STATE_CHECKPOINT_KEY] + + logging.info("Restored checkpoint for epoch %s.", epoch) + + model._initial_epoch = epoch + 1 # pylint: disable=protected-access + + keras.tree.assert_same_structure(variables, restored_variables) + for var, restored_var in zip( + jax.tree.flatten(variables)[0], jax.tree.flatten(restored_variables)[0] + ): + var._value = restored_var # pylint: disable=protected-access + + +def restore_keras_checkpoint( + checkpoint_dir: str, + *, + model: keras.Model | None = None, + epoch: int | None = None, + compile: bool = False, # pylint: disable=redefined-builtin + restore_optimizer_vars: bool = False, +) -> keras.Model: + """Restores a Keras 3 Jax backend model from an Orbax checkpoint.""" + + if keras.backend.backend() != "jax": + raise ValueError( + "This function only supports restoring a Keras 3 Jax backend model." + ) + if restore_optimizer_vars and model is None: + raise ValueError( + "To use `restore_keras_checkpoint` with `restore_optimizer_vars` set to" + " True, a model must be provided." 
+    )
+
+  metadata = ocp.path.step.latest_step_metadata(
+      checkpoint_dir, ocp.path.step.standard_name_format()
+  )
+  if metadata is None:
+    raise FileNotFoundError(
+        f"No checkpoints found in {checkpoint_dir}. Please ensure that the"
+        " checkpoint directory contains Orbax checkpoints."
+    )
+  if epoch is None:
+    epoch = metadata.step
+  elif epoch not in ocp.path.step.checkpoint_steps(checkpoint_dir):
+    raise ValueError(
+        f"Step {epoch} not found in {checkpoint_dir}. Please ensure you specify"
+        " a valid step. Available steps:"
+        f" {ocp.path.step.checkpoint_steps(checkpoint_dir)}"
+    )
+
+  checkpoint_path = ocp.path.step.build_step_path(
+      checkpoint_dir, ocp.path.step.standard_name_format(), epoch
+  )
+
+  if model is None:
+    cfg = {**load_keras_model_config(checkpoint_dir, epoch=epoch)}
+    if not compile and "compile_config" in cfg:
+      cfg.pop("compile_config")
+
+    model: keras.Model = keras.utils.deserialize_keras_object(cfg)
+    if not model.built:
+      if "build_config" not in cfg:
+        raise ValueError(
+            "To use `restore_keras_checkpoint` on a model checkpoint without"
+            " passing a model, the `build_config` must be present in the config."
+            " Make sure you have implemented `get_build_config` correctly."
+            " Generally, you shouldn't need to do this and the default"
+            " implementation should work for most cases."
+        )
+      model.build_from_config(cfg["build_config"])
+  elif not model._jax_state_synced:  # pylint: disable=protected-access
+    model.jax_state_sync()
+
+  _assert_all_layers_built(model)
+
+  variables = {
+      TRAINABLE_VARIABLES_KEY: model.trainable_variables,
+      NON_TRAINABLE_VARIABLES_KEY: model.non_trainable_variables,
+  }
+  if restore_optimizer_vars:
+    if not model.optimizer.built:
+      raise ValueError(
+          "To use `restore_keras_checkpoint` on an existing model with"
+          " `restore_optimizer_vars` set to True, the optimizer must be"
+          " built."
+      )
+    variables[OPTIMIZER_VARIABLES_KEY] = model.optimizer.variables
+
+  # TODO(zixiangzhou): Update variables to use a nested dictionary and index map
+  # instead of flattened list.
+
+  # Construct abstract variables to ensure the checkpoint is restored with
+  # the same sharding as the current variables.
+  abstract_state = jax.tree.map(_to_shape_dtype_struct, variables)
+
+  # Delete the variables from device memory to reduce peak memory usage.
+  for var in jax.tree.flatten(variables)[0]:
+    var.value.delete()
+    var._value = None  # pylint: disable=protected-access
+
+  # TODO(aahil): Look into converging the logic here with the checkpointing
+  # logic in KerasOrbaxCheckpointManagerV2.
+  checkpointer = ocp.Checkpointer(
+      ocp.CompositeCheckpointHandler(**{
+          STATE_CHECKPOINT_KEY: ocp.handlers.PyTreeCheckpointHandler(
+              restore_concurrent_gb=96,
+          ),
+      })
+  )
+  restored_state = checkpointer.restore(
+      checkpoint_path,
+      args=ocp.args.Composite(**{
+          STATE_CHECKPOINT_KEY: ocp.args.PyTreeRestore(
+              abstract_state,
+              transforms={},
+              restore_args=ocp.checkpoint_utils.construct_restore_args(
+                  abstract_state
+              ),
+          ),
+      }),
+  )[STATE_CHECKPOINT_KEY]
+  checkpointer.close()
+
+  # TODO(zixiangzhou): Unflatten the variables based on index here.
+ keras.tree.assert_same_structure(variables, restored_state) + for var, restored_var in zip( + jax.tree.flatten(variables)[0], jax.tree.flatten(restored_state)[0] + ): + var._value = restored_var # pylint: disable=protected-access + + if restore_optimizer_vars: + model._initial_epoch = epoch + 1 # pylint: disable=protected-access + + return model + + +@functools.lru_cache +def load_keras_model_config( + checkpoint_dir: str, epoch: int | None = None +) -> Mapping[str, Any]: + """Loads a Keras model from a checkpoint directory.""" + if keras.backend.backend() != "jax": + raise ValueError( + "This function only supports loading a Keras 3 Jax backend model." + ) + + metadata = ocp.path.step.latest_step_metadata( + checkpoint_dir, ocp.path.step.standard_name_format() + ) + if metadata is None: + raise FileNotFoundError( + f"No checkpoints found in {checkpoint_dir}. Please ensure that the" + " checkpoint directory contains Orbax checkpoints." + ) + if epoch is None: + epoch = metadata.step + elif epoch not in ocp.path.step.checkpoint_steps(checkpoint_dir): + raise ValueError( + f"Step {epoch} not found in {checkpoint_dir}. Please ensure you specify" + " a valid step. Available steps:" + f" {ocp.path.step.checkpoint_steps(checkpoint_dir)}" + ) + + checkpoint_path = ocp.path.step.build_step_path( + checkpoint_dir, ocp.path.step.standard_name_format(), epoch + ) + + json_checkpointer = ocp.Checkpointer( + ocp.CompositeCheckpointHandler( + **{CONFIG_CHECKPOINT_KEY: ocp.handlers.JsonCheckpointHandler()} + ) + ) + cfg = json_checkpointer.restore( + checkpoint_path, + args=ocp.args.Composite( + **{CONFIG_CHECKPOINT_KEY: ocp.args.JsonRestore()} + ), + )[CONFIG_CHECKPOINT_KEY] + json_checkpointer.close() + return cfg + + class KerasOrbaxCheckpointManager(ocp.CheckpointManager): """An Orbax checkpoint manager for Keras 3.""" @@ -142,7 +446,9 @@ class EpochOrbaxCheckpointAndRestoreCallback(keras.callbacks.Callback): def __init__( self, - checkpoint_manager: KerasOrbaxCheckpointManager, + checkpoint_manager: ( + KerasOrbaxCheckpointManager | KerasOrbaxCheckpointManagerV2 + ), marker_path: str | None = None, ): if keras.backend.backend() != "jax": @@ -188,6 +494,9 @@ def restore_keras_model( ): """Restores a Keras 3 Jax backend model from an Orbax checkpoint. + This is only compatible with `KerasOrbaxCheckpointManager`. If you are using + `KerasOrbaxCheckpointManagerV2`, use `restore_keras_checkpoint` instead. + Args: model: The Keras model to restore. checkpoint_dir: The directory containing the Orbax checkpoints. @@ -203,8 +512,8 @@ def restore_keras_model( restore_iterations: Whether to restore the model's iterations. If `True` then the model will continue training from the iteration the checkpoint was saved at. This is an optimizer variable used for controlling the - learning rate schedule. This is not supported if restore_optimizer_vars - is `False`. + learning rate schedule. This is not supported if restore_optimizer_vars is + `False`. Raises: FileNotFoundError: If no checkpoints are found in the checkpoint directory. 
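For orientation, a minimal save/restore round trip through the new V2 path might look like the sketch below; the model object and the "/tmp/ckpts" path are hypothetical placeholders, while the `keras_utils` entry points are the ones added above:

    # Assumes the Keras 3 JAX backend and a built, compiled `model`.
    from recml.core.utils import keras_utils

    manager = keras_utils.KerasOrbaxCheckpointManagerV2(
        checkpoint_dir="/tmp/ckpts", max_to_keep=5, save_interval_epochs=1
    )
    # Saves trainable / non-trainable / optimizer variables under the
    # "state" key and the serialized model config under "config".
    manager.save_model_variables(model, epoch=0)
    manager.wait_until_finished()
    manager.close()

    # Rebuild the model purely from the checkpointed config and state...
    restored = keras_utils.restore_keras_checkpoint("/tmp/ckpts")

    # ...or restore weights and optimizer slots into an existing model so
    # that training resumes from the checkpointed epoch.
    keras_utils.restore_keras_checkpoint(
        "/tmp/ckpts", model=model, restore_optimizer_vars=True
    )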
diff --git a/recml/core/utils/keras_utils_test.py b/recml/core/utils/keras_utils_test.py index 010707a..6e70c0a 100644 --- a/recml/core/utils/keras_utils_test.py +++ b/recml/core/utils/keras_utils_test.py @@ -14,6 +14,7 @@ """Tests or utilities.""" from collections.abc import Sequence +import json from absl import flags from absl.testing import absltest @@ -25,12 +26,15 @@ import numpy as np from recml.core.utils import keras_utils -_LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay( - initial_learning_rate=0.1, - decay_steps=100, - end_learning_rate=0.01, - power=1.0, -) + +def _create_dummy_inputs() -> dict[str, jax.Array]: + k1, k2, k3, k4 = jax.random.split(jax.random.key(42), 4) + return { + "token_ids": jax.random.randint(k1, (64, 128), minval=0, maxval=2048), + "segment_ids": jax.random.randint(k2, (64, 128), minval=0, maxval=8), + "padding_mask": jax.random.uniform(k3, (64, 128)), + "mask_positions": jax.random.randint(k4, (64, 20), minval=0, maxval=32), + } def _create_model(input_shapes: Sequence[int]) -> keras.Model: @@ -46,7 +50,14 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model: dropout=0.1, ) ) - optimizer = keras.optimizers.Adam(learning_rate=_LEARNING_RATE_SCHEDULE) + optimizer = keras.optimizers.Adam( + learning_rate=keras.optimizers.schedules.PolynomialDecay( + initial_learning_rate=0.1, + decay_steps=100, + end_learning_rate=0.01, + power=1.0, + ) + ) loss = keras.losses.SparseCategoricalCrossentropy() metrics = [keras.metrics.SparseCategoricalAccuracy()] model.compile(optimizer, loss, weighted_metrics=metrics) @@ -63,6 +74,132 @@ def setUp(self): if not flags.FLAGS.is_parsed(): flags.FLAGS.mark_as_parsed() + @parameterized.named_parameters( + { + "testcase_name": "single_core", + "data_parallel": False, + "restore_with_checkpointer": True, + }, + { + "testcase_name": "data_parallel", + "data_parallel": True, + "restore_with_checkpointer": True, + }, + { + "testcase_name": "restore_without_checkpointer_single_core", + "data_parallel": False, + "restore_with_checkpointer": False, + }, + { + "testcase_name": "restore_without_checkpointer_data_parallel", + "data_parallel": True, + "restore_with_checkpointer": False, + }, + ) + def test_keras_orbax_checkpointer_v2( + self, data_parallel: bool, restore_with_checkpointer: bool + ): + if data_parallel: + keras.distribution.set_distribution(keras.distribution.DataParallel()) + else: + keras.distribution.set_distribution(None) + + checkpoint_dir = self.create_tempdir().full_path + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2( + checkpoint_dir, max_to_keep=5 + ) + dummy_inputs = _create_dummy_inputs() + + bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) + state = ( + [v.value for v in bert_pretrainer.trainable_variables], + [v.value for v in bert_pretrainer.non_trainable_variables], + [v.value for v in bert_pretrainer.optimizer.variables], + ) + checkpoint_manager.save_model_variables(bert_pretrainer, 0) + checkpoint_manager.wait_until_finished() + + preds = bert_pretrainer(dummy_inputs) + + bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) + if restore_with_checkpointer: + checkpoint_manager.restore_model_variables(bert_pretrainer, 0) + else: + keras_utils.restore_keras_checkpoint( + checkpoint_dir, model=bert_pretrainer, restore_optimizer_vars=True + ) + + checkpoint_manager.close() + + restored_state = ( + [v.value for v in bert_pretrainer.trainable_variables], + [v.value for v in bert_pretrainer.non_trainable_variables], + [v.value 
for v in bert_pretrainer.optimizer.variables], + ) + preds_after_restoration = bert_pretrainer(dummy_inputs) + + keras.tree.assert_same_structure(state, restored_state) + for expected, observed in zip( + jax.tree.flatten(state)[0], jax.tree.flatten(restored_state)[0] + ): + # Ensures the objects are different but the values are the same. + self.assertNotEqual(id(expected), id(observed)) + self.assertEqual(expected.shape, observed.shape) + self.assertEqual(expected.dtype, observed.dtype) + self.assertEqual(expected.sharding, observed.sharding) + np.testing.assert_allclose(observed, expected) + + # Ensures predictions are identical. + np.testing.assert_allclose(preds, preds_after_restoration) + + def test_restore_keras_checkpoint(self): + dummy_inputs = _create_dummy_inputs() + bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) + preds = bert_pretrainer(dummy_inputs) + + checkpoint_dir = self.create_tempdir().full_path + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2( + checkpoint_dir + ) + checkpoint_manager.save_model_variables(bert_pretrainer, epoch=1) + checkpoint_manager.close() + + restored_model = keras_utils.restore_keras_checkpoint(checkpoint_dir) + preds_after_restoration = restored_model(dummy_inputs) + + for expected, observed in zip( + [v.value for v in bert_pretrainer.variables], + [v.value for v in restored_model.variables], + ): + # Ensures the objects are different but the values are the same. + self.assertNotEqual(id(expected), id(observed)) + self.assertEqual(expected.shape, observed.shape) + self.assertEqual(expected.dtype, observed.dtype) + self.assertEqual(expected.sharding, observed.sharding) + np.testing.assert_allclose(observed, expected) + + self.assertDictEqual( + bert_pretrainer.get_config(), restored_model.get_config() + ) + np.testing.assert_allclose(preds, preds_after_restoration) + + def test_load_keras_model_config(self): + dummy_inputs = _create_dummy_inputs() + bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) + config = keras.utils.serialize_keras_object(bert_pretrainer) + config = json.loads(json.dumps(config)) # Converts tuples to lists. 
+ + checkpoint_dir = self.create_tempdir().full_path + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2( + checkpoint_dir + ) + checkpoint_manager.save_model_variables(bert_pretrainer, epoch=1) + checkpoint_manager.close() + + self.assertDictEqual( + config, keras_utils.load_keras_model_config(checkpoint_dir, epoch=1) + ) + @parameterized.named_parameters( { "testcase_name": "single_core", @@ -80,7 +217,7 @@ def setUp(self): "restore_with_checkpointer": False, }, { - "testcase_name": "restore_without_checkpointer_model_parallel", + "testcase_name": "restore_without_checkpointer_single_core", "data_parallel": False, "restore_with_checkpointer": False, }, @@ -90,44 +227,14 @@ def test_keras_orbax_checkpointer( ): if data_parallel: keras.distribution.set_distribution(keras.distribution.DataParallel()) + else: + keras.distribution.set_distribution(None) + checkpoint_dir = self.create_tempdir().full_path - checkpointer = keras_utils.KerasOrbaxCheckpointManager( + checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager( checkpoint_dir, max_to_keep=5 ) - epoch = 1 - dummy_inputs = { - "token_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=50_000 - ), - "segment_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=7 - ), - "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)), - "mask_positions": jax.random.randint( - jax.random.key(0), (64, 20), minval=0, maxval=128 - ), - } - - def _create_model(input_shapes: Sequence[int]) -> keras.Model: - model = keras_hub.models.BertMaskedLM( - backbone=keras_hub.models.BertBackbone( - vocabulary_size=50_000, - num_layers=10, - num_heads=8, - hidden_dim=256, - intermediate_dim=3072, - max_sequence_length=128, - num_segments=7, - dropout=0.1, - ) - ) - optimizer = keras.optimizers.Adam(learning_rate=0.1) - loss = keras.losses.SparseCategoricalCrossentropy() - metrics = [keras.metrics.SparseCategoricalAccuracy()] - model.compile(optimizer, loss, weighted_metrics=metrics) - model.build(input_shapes) - optimizer.build(model.trainable_variables) - return model + dummy_inputs = _create_dummy_inputs() bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) state = ( @@ -135,14 +242,18 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model: [v.value for v in bert_pretrainer.non_trainable_variables], [v.value for v in bert_pretrainer.optimizer.variables], ) - checkpointer.save_model_variables(bert_pretrainer, epoch) + checkpoint_manager.save_model_variables(bert_pretrainer, epoch=1) + checkpoint_manager.wait_until_finished() preds = bert_pretrainer(dummy_inputs) bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) if restore_with_checkpointer: - checkpointer.restore_model_variables(bert_pretrainer, epoch) + checkpoint_manager.restore_model_variables(bert_pretrainer, epoch=1) else: keras_utils.restore_keras_model(bert_pretrainer, checkpoint_dir) + + checkpoint_manager.close() + restored_state = ( [v.value for v in bert_pretrainer.trainable_variables], [v.value for v in bert_pretrainer.non_trainable_variables], @@ -161,24 +272,12 @@ def _close(a: jax.Array, b: jax.Array): self.assertTrue(_close(preds, preds_after_restoration)) def test_restore_keras_model_error_cases(self): + dummy_inputs = _create_dummy_inputs() + bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) + checkpoint_dir = self.create_tempdir().full_path checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir) - epoch = 2 - dummy_inputs = { - "token_ids": 
jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=50_000 - ), - "segment_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=7 - ), - "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)), - "mask_positions": jax.random.randint( - jax.random.key(0), (64, 20), minval=0, maxval=128 - ), - } - - bert_pretrainer = _create_model(jax.tree.map(jnp.shape, dummy_inputs)) - checkpointer.save_model_variables(bert_pretrainer, epoch) + checkpointer.save_model_variables(bert_pretrainer, epoch=2) checkpointer.wait_until_finished() with self.assertRaises(ValueError): keras_utils.restore_keras_model(bert_pretrainer, checkpoint_dir, step=0) @@ -202,18 +301,7 @@ def test_metrics_variables_checkpointing( checkpoint_dir = self.create_tempdir().full_path checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir) epoch = 1 - dummy_inputs = { - "token_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=50_000 - ), - "segment_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=7 - ), - "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)), - "mask_positions": jax.random.randint( - jax.random.key(0), (64, 20), minval=0, maxval=128 - ), - } + dummy_inputs = _create_dummy_inputs() source_bert_pretrainer = _create_model( jax.tree.map(jnp.shape, dummy_inputs) @@ -308,19 +396,7 @@ def test_restore_keras_model_with_different_options( checkpoint_dir = self.create_tempdir().full_path checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir) epoch = 1 - dummy_inputs = { - "token_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=50_000 - ), - "segment_ids": jax.random.randint( - jax.random.key(0), (64, 128), minval=0, maxval=7 - ), - "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)), - "mask_positions": jax.random.randint( - jax.random.key(0), (64, 20), minval=0, maxval=128 - ), - } - + dummy_inputs = _create_dummy_inputs() source_bert_pretrainer = _create_model( jax.tree.map(jnp.shape, dummy_inputs) ) diff --git a/recml/examples/DLRM_HSTU/action_encoder.py b/recml/examples/DLRM_HSTU/action_encoder.py new file mode 100644 index 0000000..1caed5b --- /dev/null +++ b/recml/examples/DLRM_HSTU/action_encoder.py @@ -0,0 +1,123 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of the ActionEncoder module.""" + +from typing import Dict, List, Optional, Tuple + +import flax.linen as nn +from flax.linen import initializers +import jax +import jax.numpy as jnp + + +class ActionEncoder(nn.Module): + """Encodes categorical actions and continuous watch times into a fixed-size embedding. + + assumes dense tensors of shape (batch_size, sequence_length) for all inputs. 
+ """ + + action_embedding_dim: int + action_feature_name: str + action_weights: List[int] + watchtime_feature_name: str = "" + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None + + def setup(self): + """Initializes parameters and constants for the module.""" + wt_thresholds_and_weights = ( + self.watchtime_to_action_thresholds_and_weights or [] + ) + + self.combined_action_weights = jnp.array( + list(self.action_weights) + [w for _, w in wt_thresholds_and_weights] + ) + + self.num_action_types: int = ( + len(self.action_weights) + len(wt_thresholds_and_weights) + ) + + self.action_embedding_table = self.param( + "action_embedding_table", + initializers.normal(stddev=0.1), + (self.num_action_types, self.action_embedding_dim), + ) + + self.target_action_embedding_table = self.param( + "target_action_embedding_table", + initializers.normal(stddev=0.1), + (1, self.output_embedding_dim), + ) + + @property + def output_embedding_dim(self) -> int: + """The dimension of the final output embedding.""" + num_watchtime_actions = ( + len(self.watchtime_to_action_thresholds_and_weights) + if self.watchtime_to_action_thresholds_and_weights + else 0 + ) + num_action_types = len(self.action_weights) + num_watchtime_actions + return self.action_embedding_dim * num_action_types + + def __call__( + self, + seq_payloads: Dict[str, jax.Array], + is_target_mask: jax.Array, + ) -> jax.Array: + """Processes a batch of sequences to generate action embeddings. + + Args: + seq_payloads: A dictionary of feature names to dense tensors of shape + `(batch_size, sequence_length)`. + is_target_mask: A boolean tensor of shape `(batch_size, + sequence_length)` where `True` indicates a target item. + + Returns: + A dense tensor of action embeddings of shape + `(batch_size, sequence_length, output_embedding_dim)`. + """ + + seq_actions = seq_payloads[self.action_feature_name] + + wt_thresholds_and_weights = ( + self.watchtime_to_action_thresholds_and_weights or [] + ) + if wt_thresholds_and_weights: + watchtimes = seq_payloads[self.watchtime_feature_name] + for threshold, weight in wt_thresholds_and_weights: + watch_action = (watchtimes >= threshold).astype(jnp.int64) * weight + seq_actions = jnp.bitwise_or(seq_actions, watch_action) + + exploded_actions = ( + jnp.bitwise_and(seq_actions[..., None], self.combined_action_weights) + > 0 + ) + + history_embeddings = ( + exploded_actions[..., None] * self.action_embedding_table + ).reshape(*seq_actions.shape, -1) + + target_embeddings = jnp.broadcast_to( + self.target_action_embedding_table, history_embeddings.shape + ) + + final_embeddings = jnp.where( + is_target_mask[..., None], + target_embeddings, + history_embeddings, + ) + + return final_embeddings diff --git a/recml/examples/DLRM_HSTU/action_encoder_test.py b/recml/examples/DLRM_HSTU/action_encoder_test.py new file mode 100644 index 0000000..6706bd2 --- /dev/null +++ b/recml/examples/DLRM_HSTU/action_encoder_test.py @@ -0,0 +1,147 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import jax
+import jax.numpy as jnp
+import numpy as np
+import numpy.testing as npt
+from absl.testing import absltest
+from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder
+
+
+class ActionEncoderJaxTest(absltest.TestCase):
+
+  def test_forward_and_backward(self) -> None:
+    """Tests the ActionEncoder's forward pass logic and differentiability."""
+
+    batch_size = 2
+    max_seq_len = 6
+    action_embedding_dim = 32
+    action_weights = [1, 2, 4, 8, 16]
+    watchtime_to_action_thresholds_and_weights = [
+        (30, 32), (60, 64), (100, 128),
+    ]
+    num_action_types = len(action_weights) + len(
+        watchtime_to_action_thresholds_and_weights
+    )
+    output_dim = action_embedding_dim * num_action_types
+    combined_action_weights = action_weights + [
+        w for _, w in watchtime_to_action_thresholds_and_weights
+    ]
+
+    enabled_actions = [
+        [0],  # Seq 1, Item 1
+        [0, 1],  # Seq 1, Item 2
+        [1, 3, 4],  # Seq 1, Item 3
+        [1, 2, 3, 4],  # Seq 1, Item 4
+        [1, 2],  # Seq 2, Item 1
+        [2],  # Seq 2, Item 2
+    ]
+    watchtimes_flat = [40, 20, 110, 31, 26, 55]
+
+    # Add actions based on watchtime thresholds, mirroring the encoder's
+    # `watchtimes >= threshold` comparison.
+    for i, wt in enumerate(watchtimes_flat):
+      for j, (threshold, _) in enumerate(
+          watchtime_to_action_thresholds_and_weights
+      ):
+        if wt >= threshold:
+          enabled_actions[i].append(j + len(action_weights))
+
+    actions_flat = [
+        sum([combined_action_weights[t] for t in x]) for x in enabled_actions
+    ]
+
+    padded_actions = np.zeros((batch_size, max_seq_len), dtype=np.int64)
+    padded_watchtimes = np.zeros((batch_size, max_seq_len), dtype=np.int64)
+
+    padded_actions[0, :4] = actions_flat[0:4]
+    padded_actions[1, :2] = actions_flat[4:6]
+    padded_watchtimes[0, :4] = watchtimes_flat[0:4]
+    padded_watchtimes[1, :2] = watchtimes_flat[4:6]
+
+    is_target_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
+    is_target_mask[0, 4:6] = True
+    is_target_mask[1, 2] = True
+
+    padding_mask = np.zeros((batch_size, max_seq_len), dtype=bool)
+    padding_mask[0, :6] = True
+    padding_mask[1, :3] = True
+
+    seq_payloads = {
+        "watchtimes": jnp.array(padded_watchtimes),
+        "actions": jnp.array(padded_actions),
+    }
+
+    encoder = ActionEncoder(
+        watchtime_feature_name="watchtimes",
+        action_feature_name="actions",
+        action_weights=action_weights,
+        watchtime_to_action_thresholds_and_weights=(
+            watchtime_to_action_thresholds_and_weights
+        ),
+        action_embedding_dim=action_embedding_dim,
+    )
+
+    key = jax.random.PRNGKey(0)
+    variables = encoder.init(key, seq_payloads, is_target_mask)
+    params = variables["params"]
+
+    action_embeddings = encoder.apply(
+        variables, seq_payloads, is_target_mask
+    )
+
+    self.assertEqual(
+        action_embeddings.shape, (batch_size, max_seq_len, output_dim)
+    )
+
+    action_table = params["action_embedding_table"]
+    target_table_flat = params["target_action_embedding_table"]
+    target_table = target_table_flat.reshape(num_action_types, -1)
+
+    history_item_idx = 0
+    for b in range(batch_size):
+      for s in range(max_seq_len):
+        if not padding_mask[b, s]:
+          npt.assert_allclose(action_embeddings[b, s], 0, atol=1e-6)
+          continue
+
+        embedding = action_embeddings[b, s].reshape(num_action_types, -1)
+
+        if is_target_mask[b, s]:
+          npt.assert_allclose(embedding, target_table, atol=1e-6)
+        else:
+          current_enabled = enabled_actions[history_item_idx]
+          for atype in range(num_action_types):
+            if atype in current_enabled:
+              npt.assert_allclose(
+                  embedding[atype], action_table[atype],
atol=1e-6 + ) + else: + npt.assert_allclose(embedding[atype], + jnp.zeros_like(embedding[atype]), + atol=1e-6) + history_item_idx += 1 + + def loss_fn(p): + return encoder.apply({"params": p}, seq_payloads, is_target_mask).sum() + + grads = jax.grad(loss_fn)(params) + self.assertIsNotNone(grads) + self.assertFalse(np.all(np.isclose(grads["action_embedding_table"], 0))) + self.assertFalse(np.all( + np.isclose(grads["target_action_embedding_table"], 0) + )) + + +if __name__ == "__main__": + absltest.main() diff --git a/recml/examples/DLRM_HSTU/content_encoder.py b/recml/examples/DLRM_HSTU/content_encoder.py new file mode 100644 index 0000000..f4cc8b8 --- /dev/null +++ b/recml/examples/DLRM_HSTU/content_encoder.py @@ -0,0 +1,170 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of ContentEncoder for dense tensors.""" + +from typing import Dict, List, Optional + +import flax.linen as nn +from jax import numpy as jnp + + +class ContentEncoder(nn.Module): + """JAX/Flax implementation of ContentEncoder for dense tensors. + + This module concatenates input embeddings with additional features. It + handles two types of features: + 1. `additional_content_features`: Features available for the entire + sequence. + 2. `target_enrich_features`: Features available only for target items, with + a learned dummy embedding used as a placeholder for history items. + """ + input_embedding_dim: int + additional_content_features: Optional[Dict[str, int]] = None + target_enrich_features: Optional[Dict[str, int]] = None + + def setup(self) -> None: + self._additional_content_features_internal: Dict[str, int] = ( + self.additional_content_features + if self.additional_content_features is not None + else {} + ) + self._target_enrich_features_internal: Dict[str, int] = ( + self.target_enrich_features + if self.target_enrich_features is not None + else {} + ) + + self._target_enrich_dummy_embeddings = { + name: self.param( + f"target_enrich_dummy_param_{name}", + nn.initializers.normal(stddev=0.1), + (1, dim), # Shape is (1, feature_dim) for broadcasting + ) + for name, dim in self._target_enrich_features_internal.items() + } + + @property + def output_embedding_dim(self) -> int: + """The total dimension of the output embeddings after concatenation.""" + additional_dim = sum( + self.additional_content_features.values() + if self.additional_content_features + else [] + ) + enrich_dim = sum( + self.target_enrich_features.values() + if self.target_enrich_features + else [] + ) + return self.input_embedding_dim + additional_dim + enrich_dim + + @nn.compact + def __call__( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Forward pass for the ContentEncoder. + + Args: + max_uih_len: The length of the user interaction history (non-target part) + in the padded sequence. + seq_embeddings: The base embeddings for the sequence with shape + (batch_size, seq_len, input_embedding_dim). 
+      seq_payloads: A dictionary mapping feature names to their tensors.
+        - For `additional_content_features`, shape is (batch_size, seq_len,
+          feature_dim).
+        - For `target_enrich_features`, shape is (batch_size, max_targets,
+          feature_dim).
+
+    Returns:
+      The concatenated content embeddings.
+      Shape: (batch_size, seq_len, output_embedding_dim).
+    """
+    content_embeddings_list: List[jnp.ndarray] = [seq_embeddings]
+
+    if self._additional_content_features_internal:
+      for x in self._additional_content_features_internal.keys():
+        content_embeddings_list.append(
+            seq_payloads[x].astype(seq_embeddings.dtype)
+        )
+
+    if self._target_enrich_dummy_embeddings:
+      batch_size = seq_embeddings.shape[0]
+
+      for name, param in self._target_enrich_dummy_embeddings.items():
+        # If a feature is used for both additional content and target
+        # enrichment, the payload will contain the full sequence. We need to
+        # slice the target part.
+        if name in self._additional_content_features_internal:
+          full_sequence_feature = seq_payloads[name]
+          enrich_embeddings_target = full_sequence_feature[
+              :, max_uih_len:, :
+          ].astype(seq_embeddings.dtype)
+        else:
+          # Otherwise, the payload contains only the target features.
+          enrich_embeddings_target = seq_payloads[name].astype(
+              seq_embeddings.dtype
+          )
+        enrich_embeddings_uih = jnp.broadcast_to(
+            param, (batch_size, max_uih_len, param.shape[-1])
+        ).astype(seq_embeddings.dtype)
+
+        # Pad targets if necessary to match sequence length
+        num_targets = enrich_embeddings_target.shape[1]
+        num_history = max_uih_len
+        if num_history + num_targets < seq_embeddings.shape[1]:
+          padding_needed = seq_embeddings.shape[1] - (
+              num_history + num_targets
+          )
+          padding = jnp.zeros(
+              (
+                  batch_size,
+                  padding_needed,
+                  enrich_embeddings_target.shape[-1],
+              ),
+              dtype=enrich_embeddings_target.dtype,
+          )
+          enrich_embeddings_target = jnp.concatenate(
+              [enrich_embeddings_target, padding], axis=1
+          )
+
+        enrich_embeddings = jnp.concatenate(
+            [enrich_embeddings_uih, enrich_embeddings_target], axis=1
+        )
+        if enrich_embeddings.shape[1] < seq_embeddings.shape[1]:
+          padding = jnp.zeros(
+              (
+                  batch_size,
+                  seq_embeddings.shape[1] - enrich_embeddings.shape[1],
+                  enrich_embeddings.shape[2],
+              ),
+              dtype=enrich_embeddings.dtype,
+          )
+          enrich_embeddings = jnp.concatenate(
+              [enrich_embeddings, padding], axis=1
+          )
+        content_embeddings_list.append(enrich_embeddings)
+
+    if (
+        not self._additional_content_features_internal
+        and not self._target_enrich_features_internal
+    ):
+      return seq_embeddings
+    else:
+      content_embeddings = jnp.concatenate(
+          content_embeddings_list,
+          axis=-1,
+      )
+      return content_embeddings
diff --git a/recml/examples/DLRM_HSTU/content_encoder_test.py b/recml/examples/DLRM_HSTU/content_encoder_test.py
new file mode 100644
index 0000000..0c13e3e
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/content_encoder_test.py
@@ -0,0 +1,106 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from absl.testing import absltest +import jax +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder + + +class ContentEncoderTest(absltest.TestCase): + """Tests for JAX ContentEncoder.""" + + def test_forward_and_backward_pass(self) -> None: + """Verifies that the model's forward and backward passes execute without error.""" + batch_size = 2 + seq_len = 6 + num_targets = 2 + max_uih_len = seq_len - num_targets + input_embedding_dim = 32 + additional_embedding_dim = 64 + enrich_embedding_dim = 16 + + encoder = ContentEncoder( + input_embedding_dim=input_embedding_dim, + additional_content_features={ + "a0": additional_embedding_dim, + "a1": additional_embedding_dim, + }, + target_enrich_features={ + "t0": enrich_embedding_dim, + "t1": enrich_embedding_dim, + }, + ) + + key = jax.random.PRNGKey(42) + key, data_key, init_key = jax.random.split(key, 3) + + seq_embeddings = jax.random.normal( + data_key, (batch_size, seq_len, input_embedding_dim) + ) + seq_payloads = { + "a0": jax.random.normal( + data_key, (batch_size, seq_len, additional_embedding_dim) + ), + "a1": jax.random.normal( + data_key, (batch_size, seq_len, additional_embedding_dim) + ), + "t0": jax.random.normal( + data_key, (batch_size, num_targets, enrich_embedding_dim) + ), + "t1": jax.random.normal( + data_key, (batch_size, num_targets, enrich_embedding_dim) + ), + } + + params = encoder.init( + init_key, + max_uih_len, + seq_embeddings, + seq_payloads, + )["params"] + + content_embeddings = encoder.apply( + {"params": params}, + max_uih_len, + seq_embeddings, + seq_payloads, + ) + + expected_dim = ( + input_embedding_dim + + sum(encoder.additional_content_features.values()) + + sum(encoder.target_enrich_features.values()) + ) + self.assertEqual( + content_embeddings.shape, (batch_size, seq_len, expected_dim) + ) + + def loss_fn(p): + output = encoder.apply( + {"params": p}, + max_uih_len, + seq_embeddings, + seq_payloads, + ) + return jnp.sum(output) + + grads = jax.grad(loss_fn)(params) + + self.assertIsNotNone(grads) + self.assertIn("target_enrich_dummy_param_t0", grads) + self.assertIn("target_enrich_dummy_param_t1", grads) + + +if __name__ == "__main__": + absltest.main() diff --git a/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py b/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py new file mode 100644 index 0000000..fad1bd3 --- /dev/null +++ b/recml/examples/DLRM_HSTU/contextual_interleave_preprocessor.py @@ -0,0 +1,164 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""JAX/Flax implementation of ContextualInterleavePreprocessor.""" + +from typing import Callable, Dict, Tuple + +from flax import linen as nn +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder +from recml.examples.DLRM_HSTU.contextualize_mlps import ContextualizedMLP +from recml.examples.DLRM_HSTU.preprocessors import get_contextual_input_embeddings +from recml.examples.DLRM_HSTU.preprocessors import InputPreprocessor + + +class ContextualInterleavePreprocessor(InputPreprocessor): + """A JAX/Flax implementation of the ContextualInterleavePreprocessor. + + This preprocessor orchestrates content encoding, action encoding, and + contextualization using parameterized MLPs, working on dense, padded tensors. + """ + + input_embedding_dim: int + output_embedding_dim: int + contextual_feature_to_max_length: Dict[str, int] + contextual_feature_to_min_uih_length: Dict[str, int] + content_encoder: ContentEncoder + content_contextualize_mlp_fn: Callable[[], ContextualizedMLP] + action_encoder: ActionEncoder + action_contextualize_mlp_fn: Callable[[], ContextualizedMLP] + pmlp_contextual_dropout_ratio: float = 0.0 + enable_interleaving: bool = False + + def setup(self): + self._max_contextual_seq_len = sum( + self.contextual_feature_to_max_length.values() + ) + + self._content_embedding_mlp = self.content_contextualize_mlp_fn() + self._action_embedding_mlp = self.action_contextualize_mlp_fn() + + if self._max_contextual_seq_len > 0: + self._batched_contextual_linear_weights = self.param( + "batched_contextual_linear_weights", + nn.initializers.xavier_uniform(), + ( + self._max_contextual_seq_len, + self.input_embedding_dim, + self.output_embedding_dim, + ), + ) + self._batched_contextual_linear_bias = self.param( + "batched_contextual_linear_bias", + nn.initializers.zeros, + (self._max_contextual_seq_len, self.output_embedding_dim), + ) + self._pmlp_dropout = nn.Dropout(rate=self.pmlp_contextual_dropout_ratio) + + def __call__( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + *, + deterministic: bool, + ) -> Tuple[ + jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray] + ]: + batch_size, max_seq_len, _ = seq_embeddings.shape + + pmlp_contextual_embeddings = None + contextual_embeddings = None + if self._max_contextual_seq_len > 0: + contextual_input_embeddings = get_contextual_input_embeddings( + seq_mask=seq_mask, + seq_payloads=seq_payloads, + contextual_feature_to_max_length=self.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=self.contextual_feature_to_min_uih_length, + dtype=seq_embeddings.dtype, + ) + + pmlp_contextual_embeddings = self._pmlp_dropout( + contextual_input_embeddings, deterministic=deterministic + ) + + contextual_embeddings = jnp.einsum( + "bci,cio->bco", + contextual_input_embeddings.reshape( + batch_size, self._max_contextual_seq_len, self.input_embedding_dim + ), + self._batched_contextual_linear_weights, + ) + jnp.expand_dims(self._batched_contextual_linear_bias, axis=0) + + # Content Embeddings + content_embeddings = self.content_encoder( + max_uih_len=max_uih_len, + seq_embeddings=seq_embeddings, + seq_payloads=seq_payloads, + ) + content_embeddings = self._content_embedding_mlp( + seq_embeddings=content_embeddings, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + # Action 
Embeddings + seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32) + indices = jnp.arange(max_seq_len) + start_target_idx = jnp.expand_dims(seq_lengths - num_targets, axis=1) + is_target_mask = (indices >= start_target_idx) & seq_mask + + action_embeddings = self.action_encoder( + seq_payloads=seq_payloads, + is_target_mask=is_target_mask, + ) + action_embeddings = self._action_embedding_mlp( + seq_embeddings=action_embeddings, + contextual_embeddings=pmlp_contextual_embeddings, + ) + + # Combine + output_seq_embeddings = content_embeddings + action_embeddings + output_seq_embeddings *= jnp.expand_dims(seq_mask, axis=-1) + output_mask = seq_mask + output_timestamps = seq_timestamps + + # Prepend contextual embeddings + if self._max_contextual_seq_len > 0: + output_seq_embeddings = jnp.concatenate( + [contextual_embeddings, output_seq_embeddings], axis=1 + ) + contextual_mask = jnp.ones( + (batch_size, self._max_contextual_seq_len), dtype=jnp.bool_ + ) + output_mask = jnp.concatenate([contextual_mask, seq_mask], axis=1) + + contextual_timestamps = jnp.zeros( + (batch_size, self._max_contextual_seq_len), + dtype=seq_timestamps.dtype, + ) + output_timestamps = jnp.concatenate( + [contextual_timestamps, seq_timestamps], axis=1 + ) + + return ( + output_seq_embeddings, + output_mask, + output_timestamps, + num_targets, + seq_payloads, + ) diff --git a/recml/examples/DLRM_HSTU/contextualize_mlps.py b/recml/examples/DLRM_HSTU/contextualize_mlps.py new file mode 100644 index 0000000..d694824 --- /dev/null +++ b/recml/examples/DLRM_HSTU/contextualize_mlps.py @@ -0,0 +1,179 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains Flax modules for contextualized MLPs used in DLRM-HSTU.""" + +from typing import Optional + +from flax import linen as nn +import jax.numpy as jnp + + +class SwishLayerNorm(nn.Module): + """Custom module for Swish(LayerNorm(x)) which is x * sigmoid(LayerNorm(x)). + + This mimics the SwishLayerNorm class in the PyTorch implementation. + """ + + epsilon: float = 1e-5 + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Computes Swish(LayerNorm(x)). + + Args: + x: Input tensor. + + Returns: + The output tensor. + """ + normed_x = nn.LayerNorm(epsilon=self.epsilon, name="layernorm")(x) + return x * nn.sigmoid(normed_x) + + +class ContextualizedMLP(nn.Module): + """Abstract base class for contextualized MLPs. + + JAX/Flax doesn't strictly require this, but it is included for structural + parity with the PyTorch version. + + This module assumes dense inputs, where ragged tensors have been padded. + """ + + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray], + ) -> jnp.ndarray: + """Forward pass for contextualized MLPs. + + Args: + seq_embeddings: Dense tensor of shape (B, N, D_in). + contextual_embeddings: Dense tensor of shape (B, D_ctx). + + Returns: + Output tensor. 
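+      In the concrete subclasses this has shape (B, N, sequential_output_dim).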
+ """ + raise NotImplementedError() + + +class SimpleContextualizedMLP(ContextualizedMLP): + """A simple MLP applied to sequential embeddings, ignoring contextual ones. + + This module is analogous to the PyTorch version and works on dense tensors. + """ + + sequential_input_dim: int + sequential_output_dim: int + hidden_dim: int + + @nn.compact + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + """Applies a simple MLP to the sequence embeddings. + + Args: + seq_embeddings: Dense tensor of shape (B, N, sequential_input_dim). + contextual_embeddings: Ignored. + + Returns: + Output tensor of shape (B, N, sequential_output_dim). + """ + x = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="mlp_0", + )(seq_embeddings) + x = SwishLayerNorm(name="mlp_1")(x) + + x = nn.Dense( + features=self.sequential_output_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="mlp_2", + )(x) + x = nn.LayerNorm(name="mlp_3")(x) + return x + + +class ParameterizedContextualizedMLP(ContextualizedMLP): + """An MLP whose weights are parameterized by contextual embeddings. + + This module is analogous to the PyTorch version and works on dense tensors. + """ + + contextual_embedding_dim: int + sequential_input_dim: int + sequential_output_dim: int + hidden_dim: int + + @nn.compact + def __call__( + self, + seq_embeddings: jnp.ndarray, + contextual_embeddings: Optional[jnp.ndarray], + ) -> jnp.ndarray: + """Applies a parameterized MLP to the sequence embeddings. + + Args: + seq_embeddings: Dense tensor of shape (B, N, sequential_input_dim). + contextual_embeddings: Dense tensor of shape + (B, contextual_embedding_dim). + + Returns: + Output tensor of shape (B, N, sequential_output_dim). + """ + if contextual_embeddings is None: + raise ValueError( + "contextual_embeddings cannot be None for " + "ParameterizedContextualizedMLP" + ) + + shared_input = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="dense_features_compress" + )(contextual_embeddings) + + attn_raw_weights_flat = nn.Dense( + features=self.sequential_input_dim * self.sequential_output_dim, + name="attn_raw_weights_0" + )(shared_input) + + batch_size = contextual_embeddings.shape[0] + attn_weights_unnorm = attn_raw_weights_flat.reshape( + batch_size, self.sequential_input_dim, self.sequential_output_dim + ) + + attn_weights = nn.LayerNorm( + feature_axes=(-2, -1), + name="attn_weights_norm" + )(attn_weights_unnorm) + + res_x = nn.Dense( + features=self.hidden_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="res_weights_0" + )(shared_input) + res_x = SwishLayerNorm(name="res_weights_1")(res_x) + bias = nn.Dense( + features=self.sequential_output_dim, + kernel_init=nn.initializers.xavier_uniform(), + name="res_weights_2" + )(res_x) + + bmm_out = jnp.matmul(seq_embeddings, attn_weights) + bias_broadcast = jnp.expand_dims(bias, axis=1) + return bmm_out + bias_broadcast diff --git a/recml/examples/DLRM_HSTU/dlrm_hstu.py b/recml/examples/DLRM_HSTU/dlrm_hstu.py new file mode 100644 index 0000000..1da975c --- /dev/null +++ b/recml/examples/DLRM_HSTU/dlrm_hstu.py @@ -0,0 +1,519 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of DLRM-HSTU.""" + +from dataclasses import dataclass +from dataclasses import field +from functools import partial +import logging +from typing import Any, Dict, List, Optional, Tuple +from etils import epy +import flax.linen as nn +from flax.linen.initializers import xavier_uniform +from flax.linen.initializers import zeros +import jax +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.action_encoder import ActionEncoder +from recml.examples.DLRM_HSTU.content_encoder import ContentEncoder +from recml.examples.DLRM_HSTU.contextual_interleave_preprocessor import ContextualInterleavePreprocessor +from recml.examples.DLRM_HSTU.contextualize_mlps import ContextualizedMLP +from recml.examples.DLRM_HSTU.contextualize_mlps import ParameterizedContextualizedMLP +from recml.examples.DLRM_HSTU.contextualize_mlps import SimpleContextualizedMLP +from recml.examples.DLRM_HSTU.hstu_transducer import HSTUTransducer +from recml.examples.DLRM_HSTU.multitask_module import DefaultMultitaskModule +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig +from recml.examples.DLRM_HSTU.positional_encoder import HSTUPositionalEncoder +from recml.examples.DLRM_HSTU.postprocessors import L2NormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import LayerNormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import TimestampLayerNormPostprocessor +from recml.examples.DLRM_HSTU.preprocessors import SwishLayerNorm +from recml.examples.DLRM_HSTU.stu import STULayer +from recml.examples.DLRM_HSTU.stu import STULayerConfig +from recml.examples.DLRM_HSTU.stu import STUStack +from recml.layers.linen import sparsecore + +with epy.lazy_imports(): + # pylint: disable=g-import-not-at-top + from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + # pylint: enable=g-import-not-at-top + +logger = logging.getLogger(__name__) + +Dtype = Any +Array = jnp.ndarray + + +@dataclass +class EmbeddingConfig: + """Simplified embedding config for JAX.""" + + name: str + num_embeddings: int + embedding_dim: int + + +@dataclass +class DlrmHSTUConfig: + """Configuration for DLRM-HSTU model.""" + + sparsecore_config: sparsecore.SparsecoreConfig + max_seq_len: int = 2056 + max_num_candidates: int = 10 + max_num_candidates_inference: int = 5 + hstu_num_heads: int = 1 + hstu_attn_linear_dim: int = 256 + hstu_attn_qk_dim: int = 128 + hstu_attn_num_layers: int = 12 + hstu_embedding_table_dim: int = 192 + hstu_preprocessor_hidden_dim: int = 256 + hstu_transducer_embedding_dim: int = 256 # changed from 0 + hstu_group_norm: bool = False + hstu_input_dropout_ratio: float = 0.2 + hstu_linear_dropout_rate: float = 0.2 + hstu_max_attn_len: int = 0 + contextual_feature_to_max_length: Dict[str, int] = field(default_factory=dict) + contextual_feature_to_min_uih_length: Dict[str, int] = field( + default_factory=dict + ) + additional_content_features: Optional[Dict[str, int]] = None + target_enrich_features: Optional[Dict[str, int]] = None + pmlp_contextual_dropout_ratio: float = 0.0 + 
candidates_weight_feature_name: str = "" + candidates_watchtime_feature_name: str = "" + candidates_querytime_feature_name: str = "" + watchtime_feature_name: str = "" + causal_multitask_weights: float = 0.2 + multitask_configs: List[TaskConfig] = field(default_factory=list) + user_embedding_feature_names: List[str] = field(default_factory=list) + item_embedding_feature_names: List[str] = field(default_factory=list) + uih_post_id_feature_name: str = "" + uih_action_time_feature_name: str = "" + uih_weight_feature_name: str = "" + hstu_uih_feature_names: List[str] = field(default_factory=list) + hstu_candidate_feature_names: List[str] = field(default_factory=list) + merge_uih_candidate_feature_mapping: List[Tuple[str, str]] = field( + default_factory=list + ) + action_weights: Optional[List[int]] = None + watchtime_to_action_thresholds_and_weights: Optional[ + List[Tuple[int, int]] + ] = None + enable_postprocessor: bool = True + use_layer_norm_postprocessor: bool = False + + +def _get_supervision_labels_and_weights( + supervision_bitmasks: jnp.ndarray, + watchtime_sequence: jnp.ndarray, + task_configs: List[TaskConfig], + candidate_padding_mask: jnp.ndarray, +) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]: + """Computes supervision labels and weights for multitask learning.""" + supervision_labels: Dict[str, jnp.ndarray] = {} + supervision_weights: Dict[str, jnp.ndarray] = {} + for task in task_configs: + if task.task_type == MultitaskTaskType.REGRESSION: + supervision_labels[task.task_name] = watchtime_sequence.astype( + jnp.float32 + ) + elif task.task_type == MultitaskTaskType.BINARY_CLASSIFICATION: + supervision_labels[task.task_name] = ( + jnp.bitwise_and(supervision_bitmasks, task.task_weight) > 0 + ).astype(jnp.float32) + else: + raise RuntimeError("Unsupported MultitaskTaskType") + supervision_weights[task.task_name] = candidate_padding_mask.astype( + jnp.float32 + ) + return supervision_labels, supervision_weights + + +class PredictionMLP(nn.Module): + """MLP for multitask prediction head.""" + + hidden_dim: int + num_tasks: int + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x: Array) -> Array: + x = nn.Dense( + features=self.hidden_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = SwishLayerNorm(dtype=self.dtype)(x) + x = nn.Dense( + features=self.num_tasks, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + return x + + +class ItemMLP(nn.Module): + """MLP for processing item embeddings.""" + + hidden_dim: int + output_dim: int + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__(self, x: Array) -> Array: + x = nn.Dense( + features=self.hidden_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = SwishLayerNorm(dtype=self.dtype)(x) + x = nn.Dense( + features=self.output_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + )(x) + x = nn.LayerNorm(dtype=self.dtype)(x) + return x + + +class DlrmHSTU(nn.Module): + """JAX/Flax implementation of DLRM with HSTU user encoder. + + Operates on dense tensors. 
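+
+  Sparse ID features are looked up through a shared SparseCore embedding
+  layer; UIH and candidate sequences are concatenated along the sequence axis
+  before being fed to the HSTU transducer.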
+ """ + + hstu_configs: DlrmHSTUConfig + dtype: Dtype = jnp.float32 + mesh: jax.sharding.AbstractMesh | None = None + + def setup(self): + self._embedding_layer = sparsecore.SparsecoreEmbed( + self.hstu_configs.sparsecore_config, + mesh=self.mesh, + ) + self._multitask_configs: List[TaskConfig] = ( + self.hstu_configs.multitask_configs + ) + + self._multitask_module = DefaultMultitaskModule( + task_configs=self._multitask_configs, + embedding_dim=self.hstu_configs.hstu_transducer_embedding_dim, + prediction_fn=lambda in_dim, num_tasks: PredictionMLP( + hidden_dim=512, num_tasks=num_tasks, dtype=self.dtype + ), + causal_multitask_weights=self.hstu_configs.causal_multitask_weights, + ) + + hstu_config = self.hstu_configs + + content_encoder = ContentEncoder( + input_embedding_dim=hstu_config.hstu_embedding_table_dim, + additional_content_features=hstu_config.additional_content_features, + target_enrich_features=hstu_config.target_enrich_features, + ) + + action_encoder = ActionEncoder( + action_embedding_dim=hstu_config.hstu_transducer_embedding_dim, + action_feature_name=hstu_config.uih_weight_feature_name, + action_weights=hstu_config.action_weights, + watchtime_feature_name=hstu_config.watchtime_feature_name, + watchtime_to_action_thresholds_and_weights=hstu_config.watchtime_to_action_thresholds_and_weights, + ) + + contextual_embedding_dim = sum( + hstu_config.contextual_feature_to_max_length.values() + ) * hstu_config.hstu_embedding_table_dim + + def mlp_fn( + sequential_input_dim: int, + ) -> ContextualizedMLP: + if contextual_embedding_dim > 0: + return ParameterizedContextualizedMLP( + contextual_embedding_dim=contextual_embedding_dim, + sequential_input_dim=sequential_input_dim, + sequential_output_dim=hstu_config.hstu_transducer_embedding_dim, + hidden_dim=hstu_config.hstu_preprocessor_hidden_dim, + ) + else: + return SimpleContextualizedMLP( + sequential_input_dim=sequential_input_dim, + sequential_output_dim=hstu_config.hstu_transducer_embedding_dim, + hidden_dim=hstu_config.hstu_preprocessor_hidden_dim, + ) + + preprocessor = ContextualInterleavePreprocessor( + input_embedding_dim=hstu_config.hstu_embedding_table_dim, + output_embedding_dim=hstu_config.hstu_transducer_embedding_dim, + contextual_feature_to_max_length=hstu_config.contextual_feature_to_max_length, + contextual_feature_to_min_uih_length=hstu_config.contextual_feature_to_min_uih_length, + content_encoder=content_encoder, + content_contextualize_mlp_fn=partial( + mlp_fn, sequential_input_dim=content_encoder.output_embedding_dim + ), + action_encoder=action_encoder, + action_contextualize_mlp_fn=partial( + mlp_fn, sequential_input_dim=action_encoder.output_embedding_dim + ), + pmlp_contextual_dropout_ratio=hstu_config.pmlp_contextual_dropout_ratio, + ) + + contextual_seq_len = sum( + hstu_config.contextual_feature_to_max_length.values() + ) + positional_encoder = HSTUPositionalEncoder( + num_position_buckets=8192, + num_time_buckets=2048, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + contextual_seq_len=contextual_seq_len, + ) + + if hstu_config.enable_postprocessor: + if hstu_config.use_layer_norm_postprocessor: + postproc_cls = partial( + LayerNormPostprocessor, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + eps=1e-5, + dtype=self.dtype, + ) + else: + postproc_cls = partial( + TimestampLayerNormPostprocessor, + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + time_duration_features=[(60 * 60, 24), (24 * 60 * 60, 7)], + eps=1e-5, + dtype=self.dtype, + ) + else: + postproc_cls = 
L2NormPostprocessor + + stu_configs = [] + for _ in range(hstu_config.hstu_attn_num_layers): + stu_layer_config = STULayerConfig( + embedding_dim=hstu_config.hstu_transducer_embedding_dim, + num_heads=hstu_config.hstu_num_heads, + hidden_dim=hstu_config.hstu_attn_linear_dim, + attention_dim=hstu_config.hstu_attn_qk_dim, + output_dropout_ratio=hstu_config.hstu_linear_dropout_rate, + causal=True, + target_aware=True, + use_group_norm=hstu_config.hstu_group_norm, + contextual_seq_len=contextual_seq_len, + max_attn_len=hstu_config.hstu_max_attn_len, + ) + stu_configs.append(stu_layer_config) + stu_module = STUStack(configs=stu_configs) + + self._hstu_transducer = HSTUTransducer( + stu_module=stu_module, + input_preprocessor=preprocessor, + output_postprocessor_cls=postproc_cls, + input_dropout_ratio=hstu_config.hstu_input_dropout_ratio, + positional_encoder=positional_encoder, + return_full_embeddings=False, + listwise=False, + ) + + self._item_embedding_mlp = ItemMLP( + hidden_dim=512, + output_dim=hstu_config.hstu_transducer_embedding_dim, + dtype=self.dtype, + ) + + def _concat_features( + self, uih_tensor: Array, cand_tensor: Array + ) -> Array: + """Concatenates dense UIH and candidate tensors along sequence dim.""" + return jnp.concatenate([uih_tensor, cand_tensor], axis=1) + + def _construct_payload( + self, + uih_features: Dict[str, Array], + cand_features: Dict[str, Array], + embeddings: Any, + ) -> Dict[str, Array]: + """Constructs payload dictionary for HSTUTransducer.""" + payload = {} + for name in self.hstu_configs.contextual_feature_to_max_length: + if name in embeddings: + payload[name] = embeddings[name] + elif name in uih_features: # non-embedding contextual feature + payload[name] = uih_features[name] + elif name in cand_features: + payload[name] = cand_features[name] + + for ( + uih_name, + cand_name, + ) in self.hstu_configs.merge_uih_candidate_feature_mapping: + is_sc_feature = ( + uih_name in self.hstu_configs.sparsecore_config.specs + ) + # Handle embedding features that need to be in the payload. + if uih_name in embeddings: + if uih_name not in payload: + if is_sc_feature: + payload[uih_name] = embeddings[uih_name] + else: + payload[uih_name] = self._concat_features( + embeddings[uih_name], embeddings[cand_name] + ) + # Handle non-embedding features that need to be merged. + elif uih_name in uih_features and cand_name in cand_features: + if uih_name not in payload: + payload[uih_name] = uih_features[uih_name] + + # Handle features that only exist for candidates (for target enrichment) + if self.hstu_configs.target_enrich_features: + for feat_name in self.hstu_configs.target_enrich_features: + if feat_name in embeddings and feat_name not in payload: + payload[feat_name] = embeddings[feat_name] + return payload + + def __call__( + self, + features: Dict[str, Array], + uih_lengths: Array, + num_candidates: Array, + *, + deterministic: bool, + decode: bool = False, + ) -> Tuple[ + Array, + Array, + Dict[str, Array], + Optional[Array], + Optional[Array], + Optional[Array], + ]: + """Forward pass for DLRM-HSTU. + + Args: + features: Dict of dense feature tensors. When using SparseCore, this + dictionary is the output of SparsecorePreprocessor and must contain + concatenated sequences for sparse features, along with other dense + features like 'action_time', 'query_time' etc. + uih_lengths: Length of UIH sequences (B,). + num_candidates: Number of candidates per example (B,). + deterministic: If true, disable dropout. 
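+      decode: Whether to run the STU attention stack in decode mode. Requires
+        `deterministic=True`.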
+ + Returns: + Tuple of (user_embeddings, item_embeddings, aux_losses, + preds, labels, weights). + """ + max_uih_len = features[ + self.hstu_configs.uih_action_time_feature_name + ].shape[1] + max_candidates = features[ + self.hstu_configs.candidates_querytime_feature_name + ].shape[1] + embeddings: Any = self._embedding_layer(features) + + candidate_padding_mask = jnp.arange(max_candidates) < num_candidates[:, None] + + if self.hstu_configs.uih_post_id_feature_name not in embeddings: + raise ValueError( + "Post ID feature " + f"{self.hstu_configs.uih_post_id_feature_name} not found in " + "merged embeddings." + ) + cand_item_embeddings_for_mlp = jnp.concatenate( + [ + embeddings[k][:, max_uih_len:, :] + for k in self.hstu_configs.item_embedding_feature_names + ], + axis=-1, + ) + item_embeddings_candidates = self._item_embedding_mlp( + cand_item_embeddings_for_mlp + ) + + payload = self._construct_payload( + features, features, embeddings + ) + hstu_seq_lengths = uih_lengths + num_candidates + hstu_seq_embeddings = embeddings[ + self.hstu_configs.uih_post_id_feature_name + ] + + candidate_querytime_feature_name = ( + self.hstu_configs.candidates_querytime_feature_name + ) + hstu_seq_timestamps = self._concat_features( + features[self.hstu_configs.uih_action_time_feature_name], + features[candidate_querytime_feature_name], + ) + + user_embeddings_candidates, _ = self._hstu_transducer( + max_uih_len=max_uih_len, + max_targets=max_candidates, + total_uih_len=0, # Not used in dense tensor implementation + total_targets=0, # Not used in dense tensor implementation + seq_lengths=hstu_seq_lengths, + seq_embeddings=hstu_seq_embeddings, + seq_timestamps=hstu_seq_timestamps, + num_targets=num_candidates, + seq_payloads=payload, + deterministic=deterministic, + decode=decode, + ) + + supervision_bitmasks = features[ + self.hstu_configs.candidates_weight_feature_name + ] + watchtime_sequence = features[ + self.hstu_configs.candidates_watchtime_feature_name + ] + supervision_labels, supervision_weights = ( + _get_supervision_labels_and_weights( + supervision_bitmasks, + watchtime_sequence, + self._multitask_configs, + candidate_padding_mask, + ) + ) + + # The HSTU transducer returns embeddings for the full sequence, with + # non-candidate parts masked. We need to slice out the candidate parts + # to match the shape of the item embeddings. + user_embeddings_candidates = user_embeddings_candidates[ + :, -max_candidates:, : + ] + + mt_target_preds, mt_target_labels, mt_target_weights, mt_losses = ( + self._multitask_module( + encoded_user_embeddings=user_embeddings_candidates, + item_embeddings=item_embeddings_candidates, + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + deterministic=deterministic, + ) + ) + + aux_losses: Dict[str, Array] = {} + if not deterministic and mt_losses is not None: + for i, task in enumerate(self._multitask_configs): + aux_losses[task.task_name] = mt_losses[i] + + return ( + user_embeddings_candidates, + item_embeddings_candidates, + aux_losses, + mt_target_preds, + mt_target_labels, + mt_target_weights, + ) diff --git a/recml/examples/DLRM_HSTU/dlrm_hstu_test.py b/recml/examples/DLRM_HSTU/dlrm_hstu_test.py new file mode 100644 index 0000000..a46b599 --- /dev/null +++ b/recml/examples/DLRM_HSTU/dlrm_hstu_test.py @@ -0,0 +1,366 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import tree_util +import jax.numpy as jnp +import numpy as np +import optax +from recml.core.training import partitioning +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTU +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTUConfig +from recml.examples.DLRM_HSTU.dlrm_hstu import EmbeddingConfig +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig +from recml.layers.linen import sparsecore +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + + +def get_sparsecore_config(embed_dim, post_id_vocab, cat_feat_vocab, seq_len): + """Returns a SparsecoreConfig for DLRM-HSTU.""" + post_id_spec = sparsecore.EmbeddingSpec( + input_dim=post_id_vocab, + embedding_dim=embed_dim, + max_sequence_length=seq_len, + ) + cat_feat_spec = sparsecore.EmbeddingSpec( + input_dim=cat_feat_vocab, + embedding_dim=embed_dim, + max_sequence_length=seq_len, + ) + return sparsecore.SparsecoreConfig( + specs={ + 'post_id': post_id_spec, + 'cat_feat': cat_feat_spec, + }, + optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01), + ) + + +class DlrmHstuTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.batch_size = 16 + self.max_uih_len = 4 + self.max_candidates = 2 + self.embed_dim = 128 + self.hstu_dim = 128 + self.post_id_vocab = 20000 + self.cat_feat_vocab = 20000 + + sc_config = get_sparsecore_config( + self.embed_dim, + self.post_id_vocab, + self.cat_feat_vocab, + seq_len=self.max_uih_len + self.max_candidates, + ) + + self.config = DlrmHSTUConfig( + sparsecore_config=sc_config, + max_seq_len=self.max_uih_len + self.max_candidates, + hstu_embedding_table_dim=self.embed_dim, + hstu_transducer_embedding_dim=self.hstu_dim, + hstu_preprocessor_hidden_dim=8, + hstu_attn_num_layers=4, + hstu_attn_linear_dim=8, + hstu_attn_qk_dim=8, + item_embedding_feature_names=['post_id', 'cat_feat'], + user_embedding_feature_names=['post_id'], + uih_post_id_feature_name='post_id', + uih_action_time_feature_name='action_time', + candidates_querytime_feature_name='query_time', + candidates_weight_feature_name='weight', + candidates_watchtime_feature_name='watch_time', + uih_weight_feature_name='action_type', + action_weights=[1, 2, 4], + merge_uih_candidate_feature_mapping=[ + ('post_id', 'post_id'), + ('cat_feat', 'cat_feat'), + ('action_type', 'action_type'), + ], + multitask_configs=[ + TaskConfig('CTR', 1, MultitaskTaskType.BINARY_CLASSIFICATION), + TaskConfig('WT', 2, MultitaskTaskType.REGRESSION), + ], + contextual_feature_to_max_length={}, + additional_content_features={'cat_feat': self.embed_dim}, + target_enrich_features={'cat_feat': self.embed_dim}, + ) + if jax.devices()[0].platform == 'tpu': + self.mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',)) + else: + self.mesh = None + + def _get_mock_data(self, key): + k1, k2, k3, k4, k5, k6, k7, k8, k9 = jax.random.split(key, 9) + uih_features = { + 'post_id': jax.random.randint( + k1, (self.batch_size, 
self.max_uih_len), 0, self.post_id_vocab + ), + 'cat_feat': jax.random.randint( + k2, (self.batch_size, self.max_uih_len), 0, self.cat_feat_vocab + ), + 'action_time': jax.random.randint( + k3, (self.batch_size, self.max_uih_len), 0, 1000 + ), + 'action_type': jax.random.randint( + k7, (self.batch_size, self.max_uih_len), 0, 8 + ), + } + candidate_features = { + 'post_id': jax.random.randint( + k1, (self.batch_size, self.max_candidates), 0, self.post_id_vocab + ), + 'cat_feat': jax.random.randint( + k2, (self.batch_size, self.max_candidates), 0, self.cat_feat_vocab + ), + 'query_time': jax.random.randint( + k4, (self.batch_size, self.max_candidates), 1000, 2000 + ), + 'weight': jax.random.randint( + k5, (self.batch_size, self.max_candidates), 0, 2 + ), # for CTR bitmask + 'watch_time': jax.random.randint( + k6, (self.batch_size, self.max_candidates), 0, 100 + ), # for WT regression + 'action_type': jax.random.randint( + k7, (self.batch_size, self.max_candidates), 0, 8 + ), + } + uih_lengths = jax.random.randint( + k8, (self.batch_size,), 1, self.max_uih_len + 1 + ).astype(jnp.int32) + num_candidates = jax.random.randint( + k9, (self.batch_size,), 1, self.max_candidates + 1 + ).astype(jnp.int32) + return uih_features, candidate_features, uih_lengths, num_candidates + + @parameterized.named_parameters( + ('train', False), + ('eval', True), + ) + def test_forward_pass(self, deterministic): + if jax.devices()[0].platform != 'tpu': + self.skipTest('Test only supported on TPUs.') + key = jax.random.PRNGKey(0) + prng_keys = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, mesh=self.mesh) + uih_features, candidate_features, uih_lengths, num_candidates = ( + self._get_mock_data(key) + ) + + features = uih_features | candidate_features + features['post_id'] = np.concatenate( + [uih_features['post_id'], candidate_features['post_id']], axis=1 + ) + features['cat_feat'] = np.concatenate( + [uih_features['cat_feat'], candidate_features['cat_feat']], axis=1 + ) + features['action_type'] = np.concatenate( + [uih_features['action_type'], candidate_features['action_type']], axis=1 + ) + + preprocessor = sparsecore.SparsecorePreprocessor( + self.config.sparsecore_config, self.batch_size + ) + sc_features = preprocessor(features) + + variables = model.init( + {'params': prng_keys[0], 'dropout': prng_keys[1]}, + sc_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + ) + + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + variables, + sc_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + rngs={'dropout': prng_keys[2]} if not deterministic else None, + ) + + num_tasks = len(self.config.multitask_configs) + expected_user_emb_shape = ( + self.batch_size, + self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(user_emb.shape, expected_user_emb_shape) + expected_item_emb_shape = ( + self.batch_size, + self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(item_emb.shape, expected_item_emb_shape) + self.assertEqual( + preds.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + if not deterministic: + self.assertNotEmpty(aux_losses) + self.assertEqual( + labels.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + self.assertEqual( + weights.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + else: + self.assertEmpty(aux_losses) + self.assertIsNone(labels) + self.assertIsNone(weights) + + def test_backward_pass_and_training(self): + if jax.devices()[0].platform != 'tpu': + 
self.skipTest('Test only supported on TPUs.') + key = jax.random.PRNGKey(1) + init_key, data_key, train_key = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, mesh=self.mesh) + uih_features, candidate_features, uih_lengths, num_candidates = ( + self._get_mock_data(data_key) + ) + + features = uih_features | candidate_features + features['post_id'] = np.concatenate( + [uih_features['post_id'], candidate_features['post_id']], axis=1 + ) + features['cat_feat'] = np.concatenate( + [uih_features['cat_feat'], candidate_features['cat_feat']], axis=1 + ) + features['action_type'] = np.concatenate( + [uih_features['action_type'], candidate_features['action_type']], axis=1 + ) + + preprocessor = sparsecore.SparsecorePreprocessor( + self.config.sparsecore_config, self.batch_size + ) + sc_features = preprocessor(features) + + variables = model.init( + {'params': init_key, 'dropout': train_key}, + sc_features, + uih_lengths, + num_candidates, + deterministic=False, + ) + params = variables['params'] + cache = variables['cache'] + + logging.info( + 'Model parameter shapes: %s', + tree_util.tree_map(lambda x: x.shape, params), + ) + logging.info('Model parameters: %s', params) + + optimizer = optax.adam(learning_rate=1e-3) + opt_state = optimizer.init(params) + + def loss_fn(params, dropout_key): + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + {'params': params, 'cache': cache}, + sc_features, + uih_lengths, + num_candidates, + deterministic=False, + rngs={'dropout': dropout_key}, + ) + return ( + user_emb.sum() + + item_emb.sum() + + preds.sum() + + sum(val.sum() for val in aux_losses.values()) + ) + + @jax.jit + def train_step(params, opt_state, dropout_key): + loss, grads = jax.value_and_grad(loss_fn)(params, dropout_key) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss + + logging.info('Starting training loop...') + for i in range(10): + step_key, train_key = jax.random.split(train_key) + params, opt_state, loss = train_step(params, opt_state, step_key) + logging.info('Step %d, Loss: %f', i, loss) + + self.assertIsNotNone(params) + + def test_dlrm_hstu_with_sparsecore(self): + if jax.devices()[0].platform != 'tpu': + self.skipTest('Test only supported on TPUs.') + key = jax.random.PRNGKey(0) + prng_keys = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, mesh=self.mesh) + uih_features, candidate_features, uih_lengths, num_candidates = ( + self._get_mock_data(key) + ) + + features = uih_features | candidate_features + features['post_id'] = np.concatenate( + [uih_features['post_id'], candidate_features['post_id']], axis=1 + ) + features['cat_feat'] = np.concatenate( + [uih_features['cat_feat'], candidate_features['cat_feat']], axis=1 + ) + features['action_type'] = np.concatenate( + [uih_features['action_type'], candidate_features['action_type']], axis=1 + ) + + preprocessor = sparsecore.SparsecorePreprocessor( + self.config.sparsecore_config, self.batch_size + ) + sc_features = preprocessor(features) + + variables = model.init( + {'params': prng_keys[0], 'dropout': prng_keys[1]}, + sc_features, + uih_lengths, + num_candidates, + deterministic=True, + ) + + user_emb, item_emb, _, preds, _, _ = model.apply( + variables, + sc_features, + uih_lengths, + num_candidates, + deterministic=True, + rngs={'dropout': prng_keys[2]}, + ) + num_tasks = len(self.config.multitask_configs) + expected_user_emb_shape = ( + self.batch_size, + 
self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(user_emb.shape, expected_user_emb_shape) + expected_item_emb_shape = ( + self.batch_size, + self.max_candidates, + self.hstu_dim, + ) + self.assertEqual(item_emb.shape, expected_item_emb_shape) + self.assertEqual( + preds.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/recml/examples/DLRM_HSTU/hstu_transducer.py b/recml/examples/DLRM_HSTU/hstu_transducer.py new file mode 100644 index 0000000..0ff41f6 --- /dev/null +++ b/recml/examples/DLRM_HSTU/hstu_transducer.py @@ -0,0 +1,242 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX/Flax implementation of HSTUTransducer for dense tensors.""" + +import logging +from typing import Dict, Optional, Tuple, Type + +import flax.linen as nn +import jax.numpy as jnp +from recml.examples.DLRM_HSTU.positional_encoder import HSTUPositionalEncoder +from recml.examples.DLRM_HSTU.postprocessors import L2NormPostprocessor +from recml.examples.DLRM_HSTU.postprocessors import OutputPostprocessor +from recml.examples.DLRM_HSTU.preprocessors import InputPreprocessor +from recml.examples.DLRM_HSTU.stu import STUStack + + +logger = logging.getLogger(__name__) + + +class HSTUTransducer(nn.Module): + """JAX/Flax implementation of the HSTU Transducer module, using dense tensors. + + This implementation mirrors structure but replaces jagged tensor operations + with dense tensor operations using masking. 
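+
+  Padded positions are zeroed out via the sequence mask instead of being
+  removed, so all tensor shapes remain static.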
+ """ + + stu_module: STUStack + input_preprocessor: InputPreprocessor + output_postprocessor_cls: Type[OutputPostprocessor] = L2NormPostprocessor + input_dropout_ratio: float = 0.0 + positional_encoder: Optional[HSTUPositionalEncoder] = None + return_full_embeddings: bool = False + listwise: bool = False + + def setup(self): + self._output_postprocessor: OutputPostprocessor = ( + self.output_postprocessor_cls() + ) + self._input_dropout = nn.Dropout(rate=self.input_dropout_ratio) + + def _preprocess( + self, + max_uih_len: int, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + is_training: bool, + ) -> Tuple[ + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + Dict[str, jnp.ndarray], + ]: + """Preprocesses the input sequence embeddings.""" + ( + output_seq_embeddings, + output_seq_mask, + output_seq_timestamps, + output_num_targets, + output_seq_payloads, + ) = self.input_preprocessor( + max_uih_len=max_uih_len, + seq_embeddings=seq_embeddings, + seq_mask=seq_mask, + seq_timestamps=seq_timestamps, + num_targets=num_targets, + seq_payloads=seq_payloads, + deterministic=not is_training, + ) + + output_seq_lengths = jnp.sum(output_seq_mask, axis=1, dtype=jnp.int32) + + if self.positional_encoder is not None: + output_seq_embeddings = self.positional_encoder( + max_seq_len=output_seq_embeddings.shape[1], + seq_lengths=output_seq_lengths, + seq_timestamps=output_seq_timestamps, + seq_embeddings=output_seq_embeddings, + num_targets=( + None if self.listwise and is_training else output_num_targets + ), + ) + + output_seq_embeddings = self._input_dropout( + output_seq_embeddings, deterministic=not is_training + ) + + return ( + output_seq_embeddings, + output_seq_mask, + output_seq_timestamps, + output_num_targets, + output_seq_payloads, + ) + + def _hstu_compute( + self, + seq_embeddings: jnp.ndarray, + num_targets: jnp.ndarray, + is_training: bool, + decode: bool = False, + ) -> jnp.ndarray: + """Computes the HSTU embeddings.""" + seq_embeddings = self.stu_module( + x=seq_embeddings, + num_targets=None if self.listwise and is_training else num_targets, + deterministic=not is_training, + decode=decode, + ) + return seq_embeddings + + def _postprocess( + self, + seq_embeddings: jnp.ndarray, + seq_mask: jnp.ndarray, + seq_timestamps: jnp.ndarray, + num_targets: jnp.ndarray, + seq_payloads: Dict[str, jnp.ndarray], + ) -> Tuple[Optional[jnp.ndarray], jnp.ndarray]: + """Postprocesses the output sequence embeddings.""" + if self.return_full_embeddings: + seq_embeddings = self._output_postprocessor( + seq_embeddings=seq_embeddings, + seq_timestamps=seq_timestamps, + seq_payloads=seq_payloads, + ) + + batch_size, max_seq_len, embedding_dim = seq_embeddings.shape + seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32) + indices = jnp.arange(max_seq_len) + start_target_idx = seq_lengths - num_targets + candidate_mask = (indices >= start_target_idx[:, jnp.newaxis]) & ( + indices < seq_lengths[:, jnp.newaxis] + ) + + candidate_embeddings_masked = ( + seq_embeddings * candidate_mask[..., jnp.newaxis] + ) + candidate_timestamps_masked = seq_timestamps * candidate_mask + + if self.input_preprocessor.interleave_targets(): + raise NotImplementedError( + "Interleaved targets not supported in dense post-processing yet." 
+      )
+
+    if not self.return_full_embeddings:
+      candidate_embeddings = self._output_postprocessor(
+          seq_embeddings=candidate_embeddings_masked,
+          seq_timestamps=candidate_timestamps_masked,
+          seq_payloads=seq_payloads,
+      )
+      candidate_embeddings = (
+          candidate_embeddings * candidate_mask[..., jnp.newaxis]
+      )
+    else:
+      candidate_embeddings = candidate_embeddings_masked
+
+    return (
+        seq_embeddings if self.return_full_embeddings else None,
+        candidate_embeddings,
+    )
+
+  def __call__(
+      self,
+      max_uih_len: int,
+      max_targets: int,
+      total_uih_len: int,
+      total_targets: int,
+      seq_lengths: jnp.ndarray,
+      seq_embeddings: jnp.ndarray,
+      seq_timestamps: jnp.ndarray,
+      num_targets: jnp.ndarray,
+      seq_payloads: Dict[str, jnp.ndarray],
+      *,
+      deterministic: Optional[bool] = None,
+      decode: bool = False,
+  ) -> Tuple[
+      jnp.ndarray,
+      Optional[jnp.ndarray],
+  ]:
+    """Forward pass for HSTUTransducer."""
+    if decode and not deterministic:
+      raise ValueError("If decode=True, deterministic must be True.")
+    # `deterministic=None` is treated as training, i.e. dropout is enabled.
+    is_training = not deterministic
+
+    _, max_len, _ = seq_embeddings.shape
+    seq_mask = (
+        jnp.arange(max_len, dtype=jnp.int32)[None, :] < seq_lengths[:, None]
+    )
+
+    (
+        processed_seq_embeddings,
+        processed_seq_mask,
+        processed_seq_timestamps,
+        processed_num_targets,
+        processed_seq_payloads,
+    ) = self._preprocess(
+        max_uih_len=max_uih_len,
+        seq_embeddings=seq_embeddings,
+        seq_mask=seq_mask,
+        seq_timestamps=seq_timestamps,
+        num_targets=num_targets,
+        seq_payloads=seq_payloads,
+        is_training=is_training,
+    )
+
+    encoded_embeddings = self._hstu_compute(
+        seq_embeddings=processed_seq_embeddings,
+        num_targets=processed_num_targets,
+        is_training=is_training,
+        decode=decode,
+    )
+
+    encoded_embeddings = (
+        encoded_embeddings * processed_seq_mask[..., jnp.newaxis]
+    )
+
+    full_embeddings, candidate_embeddings = self._postprocess(
+        seq_embeddings=encoded_embeddings,
+        seq_mask=processed_seq_mask,
+        seq_timestamps=processed_seq_timestamps,
+        num_targets=processed_num_targets,
+        seq_payloads=processed_seq_payloads,
+    )
+
+    return candidate_embeddings, full_embeddings
diff --git a/recml/examples/DLRM_HSTU/movielens_dataloader.py b/recml/examples/DLRM_HSTU/movielens_dataloader.py
new file mode 100644
index 0000000..c2c2983
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/movielens_dataloader.py
@@ -0,0 +1,181 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dataloader for MovieLens dataset using jax_recommenders.""" + +import jax.numpy as jnp +import pandas as pd + +USER_ID = 'user_id' +ITEM_ID = 'item_id' +TIMESTAMP = 'timestamp' +USER_RATING = 'user_rating' + + +class MovieLensDataLoader: + """Dataloader for MovieLens dataset.""" + + def __init__( + self, + batch_size, + max_uih_len, + max_candidates, + raw_df: pd.DataFrame, + ): + self.batch_size = batch_size + self.max_uih_len = max_uih_len + self.max_candidates = max_candidates + + if raw_df is None: + raise ValueError('raw_df must be provided') + self.raw_df = raw_df + self._create_vocabs() + self.processed_data = self._preprocess_data() + + def _create_vocabs(self): + """Creates vocabularies from the raw dataframe.""" + self.user_vocab = sorted(self.raw_df[USER_ID].unique()) + self.movie_vocab = sorted(self.raw_df[ITEM_ID].unique()) + # Movielens ratings are 0.5 to 5.0. We can map them to 0-9 + self.rating_vocab = sorted(self.raw_df[USER_RATING].unique()) + + self.user_map = {name: i for i, name in enumerate(self.user_vocab)} + self.movie_map = {name: i for i, name in enumerate(self.movie_vocab)} + self.rating_map = {name: i for i, name in enumerate(self.rating_vocab)} + + self.user_vocab_size = len(self.user_vocab) + self.movie_vocab_size = len(self.movie_vocab) + self.rating_vocab_size = len(self.rating_vocab) + self.genre_vocab_size = 1 # Genre not directly used in this simple version + + def _pad_seq(self, seq, max_len, pad_value=0): + """Pads a sequence to max_len.""" + if len(seq) > max_len: + return seq[:max_len] + return seq + [pad_value] * (max_len - len(seq)) + + def _preprocess_data(self): + """Preprocesses the raw data into batches of UIH and candidates.""" + df = self.raw_df.copy() + df[USER_ID] = df[USER_ID].map(self.user_map) + df[ITEM_ID] = df[ITEM_ID].map(self.movie_map) + df[USER_RATING] = df[USER_RATING].map(self.rating_map) + + df = df.sort_values(by=[USER_ID, TIMESTAMP]) + grouped = df.groupby(USER_ID) + + batched_data = [] + current_batch = [] + + for user_id, user_df in grouped: + history = user_df[:-self.max_candidates] + candidates = user_df[-self.max_candidates:] + + if len(history) < 1 or len(candidates) < 1: + continue + + uih_len = min(len(history), self.max_uih_len) + num_cands = len(candidates) + + uih_features = { + 'user_id': self._pad_seq( + [user_id] * uih_len, self.max_uih_len, pad_value=0 + ), + 'movie_id': self._pad_seq( + history[ITEM_ID].tolist(), self.max_uih_len + ), + 'rating': self._pad_seq( + history[USER_RATING].tolist(), self.max_uih_len + ), + 'action_time': self._pad_seq( + history[TIMESTAMP].tolist(), self.max_uih_len + ), + 'uih_weight': self._pad_seq([1] * uih_len, self.max_uih_len, 0), + 'uih_watch_time': self._pad_seq( + history[USER_RATING].tolist(), self.max_uih_len, 0 + ), + } + + candidate_features = { + 'user_id': self._pad_seq( + [user_id] * num_cands, self.max_candidates, pad_value=0 + ), + 'movie_id': self._pad_seq( + candidates[ITEM_ID].tolist(), self.max_candidates + ), + 'query_time': self._pad_seq( + candidates[TIMESTAMP].tolist(), self.max_candidates + ), + # candidates_weight is used as a mask for valid candidates in the loss + # calculation. + 'candidates_weight': self._pad_seq( + [1] * num_cands, self.max_candidates, 0 + ), + # candidates_watch_time carries the true rating values for the + # candidate items, which are used as labels for the regression task + # in the MultitaskModule. 
+          'candidates_watch_time': self._pad_seq(
+              candidates[USER_RATING].tolist(), self.max_candidates, 0
+          ),
+      }
+
+      current_batch.append({
+          'uih_features': uih_features,
+          'candidate_features': candidate_features,
+          'uih_lengths': uih_len,
+          'num_candidates': num_cands,
+      })
+
+      if len(current_batch) == self.batch_size:
+        batched_data.append(self._collate_batch(current_batch))
+        current_batch = []
+
+    # The last partial batch, if any, is dropped so that every batch has a
+    # uniform batch size.
+    return batched_data
+
+  def _collate_batch(self, batch):
+    """Collates a list of samples into a single batch of jnp arrays."""
+    collated = {}
+    if not batch:
+      return collated
+
+    keys = batch[0].keys()
+
+    for key in keys:
+      example_value = batch[0][key]
+      if isinstance(example_value, dict):
+        collated[key] = {}
+        sub_keys = example_value.keys()
+        for sub_key in sub_keys:
+          collated[key][sub_key] = jnp.array(
+              [sample[key][sub_key] for sample in batch]
+          )
+      elif isinstance(example_value, int):
+        collated[key] = jnp.array([sample[key] for sample in batch])
+      else:
+        # Handle other potential types if necessary
+        pass
+    return collated
+
+  def get_batch(self, idx):
+    """Returns a single batch by index."""
+    if idx >= len(self.processed_data):
+      raise IndexError('Batch index out of range')
+    return self.processed_data[idx]
+
+  def __len__(self):
+    return len(self.processed_data)
diff --git a/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py b/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py
new file mode 100644
index 0000000..6f33698
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/movielens_dlrm_hstu_test.py
@@ -0,0 +1,310 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for DLRM with MovieLens dataset.""" + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +import jax +import numpy as np +import optax +import pandas as pd +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTU +from recml.examples.DLRM_HSTU.dlrm_hstu import DlrmHSTUConfig +from recml.examples.DLRM_HSTU.dlrm_hstu import EmbeddingConfig +from recml.examples.DLRM_HSTU.movielens_dataloader import MovieLensDataLoader +from recml.examples.DLRM_HSTU.multitask_module import MultitaskTaskType +from recml.examples.DLRM_HSTU.multitask_module import TaskConfig +from recml.layers.linen import sparsecore +from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec + + +USER_ID = 'user_id' +ITEM_ID = 'item_id' +TIMESTAMP = 'timestamp' +USER_RATING = 'user_rating' + + +def create_dummy_movielens_df(num_users, num_items, num_events): + user_ids = np.random.randint(0, num_users, num_events) + item_ids = np.random.randint(0, num_items, num_events) + ratings = np.random.uniform(0.5, 5.0, num_events).round(1) + timestamps = np.arange(num_events) * 1000 # Increasing timestamps + df = pd.DataFrame({ + USER_ID: [f'user_{u}' for u in user_ids], + ITEM_ID: [f'item_{i}' for i in item_ids], + USER_RATING: ratings, + TIMESTAMP: timestamps, + }) + return df + + +class DlrmMovielensTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.batch_size = 16 + self.max_uih_len = 16 + self.max_candidates = 4 + self.embed_dim = 32 + self.hstu_dim = 64 + + dummy_df = create_dummy_movielens_df( + num_users=50, num_items=100, num_events=500 + ) + + self.dataloader = MovieLensDataLoader( + self.batch_size, + self.max_uih_len, + self.max_candidates, + raw_df=dummy_df, + ) + + self.user_vocab = self.dataloader.user_vocab_size + self.movie_vocab = self.dataloader.movie_vocab_size + self.rating_vocab = self.dataloader.rating_vocab_size + self.genre_vocab = self.dataloader.genre_vocab_size + + seq_len = self.max_uih_len + self.max_candidates + sc_config = sparsecore.SparsecoreConfig( + specs={ + 'user_id': sparsecore.EmbeddingSpec( + input_dim=self.user_vocab, + embedding_dim=self.embed_dim, + max_sequence_length=seq_len, + ), + 'movie_id': sparsecore.EmbeddingSpec( + input_dim=self.movie_vocab, + embedding_dim=self.embed_dim, + max_sequence_length=seq_len, + ), + 'rating': sparsecore.EmbeddingSpec( + input_dim=self.rating_vocab, + embedding_dim=self.embed_dim, + max_sequence_length=seq_len, + ), + }, + optimizer=embedding_spec.AdagradOptimizerSpec(learning_rate=0.01), + ) + + self.config = DlrmHSTUConfig( + sparsecore_config=sc_config, + max_seq_len=self.max_uih_len + self.max_candidates, + hstu_embedding_table_dim=self.embed_dim, + hstu_transducer_embedding_dim=self.hstu_dim, + hstu_preprocessor_hidden_dim=16, + hstu_attn_num_layers=2, + hstu_attn_linear_dim=16, + hstu_attn_qk_dim=16, + item_embedding_feature_names=['movie_id'], + user_embedding_feature_names=['user_id'], + uih_post_id_feature_name='movie_id', + uih_action_time_feature_name='action_time', + candidates_querytime_feature_name='query_time', + candidates_weight_feature_name='candidates_weight', + candidates_watchtime_feature_name='candidates_watch_time', + uih_weight_feature_name='uih_weight', + action_weights=[1], + merge_uih_candidate_feature_mapping=[ + ('movie_id', 'movie_id'), + ('rating', 'rating'), + ('action_time', 'query_time'), + ('user_id', 'user_id'), + ('uih_weight', 'candidates_weight'), + ('uih_watch_time', 'candidates_watch_time'), + ], + multitask_configs=[ + 
TaskConfig('RatingPrediction', 1, MultitaskTaskType.REGRESSION), + ], + contextual_feature_to_max_length={}, + additional_content_features={}, + target_enrich_features={}, + ) + if jax.devices()[0].platform == 'tpu': + self.mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',)) + else: + self.mesh = None + + @parameterized.named_parameters( + ('train', False), + ('eval', True), + ) + def test_forward_pass(self, deterministic): + if jax.devices()[0].platform != 'tpu': + self.skipTest('Test only supported on TPUs.') + key = jax.random.PRNGKey(0) + prng_keys = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, mesh=self.mesh) + + if not self.dataloader: + self.skipTest( + 'No batches were created, potentially too few users or max_candidates' + ' too high for debug data.' + ) + + batch = self.dataloader.get_batch(0) + uih_features = batch['uih_features'] + candidate_features = batch['candidate_features'] + uih_lengths = batch['uih_lengths'] + num_candidates = batch['num_candidates'] + + features = uih_features | candidate_features + features['movie_id'] = np.concatenate( + [uih_features['movie_id'], candidate_features['movie_id']], axis=1 + ) + features['user_id'] = np.concatenate( + [uih_features['user_id'], candidate_features['user_id']], axis=1 + ) + features['rating'] = np.concatenate( + [ + uih_features['rating'], + np.zeros( + (self.batch_size, self.max_candidates), + dtype=uih_features['rating'].dtype, + ), + ], + axis=1, + ) + features['uih_weight'] = np.concatenate( + [uih_features['uih_weight'], candidate_features['candidates_weight']], + axis=1, + ) + + preprocessor = sparsecore.SparsecorePreprocessor( + self.config.sparsecore_config, self.batch_size + ) + sc_features = preprocessor(features) + + variables = model.init( + {'params': prng_keys[0], 'dropout': prng_keys[1]}, + sc_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + ) + + user_emb, item_emb, aux_losses, preds, labels, weights = model.apply( + variables, + sc_features, + uih_lengths, + num_candidates, + deterministic=deterministic, + rngs={'dropout': prng_keys[2]} if not deterministic else None, + ) + + num_tasks = len(self.config.multitask_configs) + self.assertEqual( + user_emb.shape, (self.batch_size, self.max_candidates, self.hstu_dim) + ) + self.assertEqual( + item_emb.shape, (self.batch_size, self.max_candidates, self.hstu_dim) + ) + self.assertEqual( + preds.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + if not deterministic: + self.assertNotEmpty(aux_losses) + self.assertEqual( + labels.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + self.assertEqual( + weights.shape, (num_tasks, self.batch_size, self.max_candidates) + ) + + def test_backward_pass(self): + if jax.devices()[0].platform != 'tpu': + self.skipTest('Test only supported on TPUs.') + key = jax.random.PRNGKey(1) + init_key, data_key, train_key = jax.random.split(key, 3) + model = DlrmHSTU(hstu_configs=self.config, mesh=self.mesh) + + if not self.dataloader: + self.skipTest( + 'No batches were created, potentially too few users or max_candidates' + ' too high for debug data.' 
+ ) + + batch = self.dataloader.get_batch(0) + uih_features = batch['uih_features'] + candidate_features = batch['candidate_features'] + uih_lengths = batch['uih_lengths'] + num_candidates = batch['num_candidates'] + + features = uih_features | candidate_features + features['movie_id'] = np.concatenate( + [uih_features['movie_id'], candidate_features['movie_id']], axis=1 + ) + features['user_id'] = np.concatenate( + [uih_features['user_id'], candidate_features['user_id']], axis=1 + ) + features['rating'] = np.concatenate( + [ + uih_features['rating'], + np.zeros( + (self.batch_size, self.max_candidates), + dtype=uih_features['rating'].dtype, + ), + ], + axis=1, + ) + features['uih_weight'] = np.concatenate( + [uih_features['uih_weight'], candidate_features['candidates_weight']], + axis=1, + ) + + preprocessor = sparsecore.SparsecorePreprocessor( + self.config.sparsecore_config, self.batch_size + ) + sc_features = preprocessor(features) + + variables = model.init( + {'params': init_key, 'dropout': train_key}, + sc_features, + uih_lengths, + num_candidates, + deterministic=False, + ) + params = variables['params'] + cache = variables['cache'] + + optimizer = optax.adam(learning_rate=1e-3) + opt_state = optimizer.init(params) + + def loss_fn(params, dropout_key): + _, _, aux_losses, _, _, _ = model.apply( + {'params': params, 'cache': cache}, + sc_features, + uih_lengths, + num_candidates, + deterministic=False, + rngs={'dropout': dropout_key}, + ) + return sum(val.sum() for val in aux_losses.values()) + + @jax.jit + def train_step(params, opt_state, dropout_key): + loss, grads = jax.value_and_grad(loss_fn)(params, dropout_key) + updates, opt_state = optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + return params, opt_state, loss + + step_key, train_key = jax.random.split(train_key) + params, opt_state, loss = train_step(params, opt_state, step_key) + logging.info('MovieLens Test Loss: %f', loss) + self.assertIsNotNone(params) + + +if __name__ == '__main__': + absltest.main() diff --git a/recml/examples/DLRM_HSTU/multitask_module.py b/recml/examples/DLRM_HSTU/multitask_module.py new file mode 100644 index 0000000..510aeb6 --- /dev/null +++ b/recml/examples/DLRM_HSTU/multitask_module.py @@ -0,0 +1,273 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains modules and functions for handling multitask predictions and losses.""" + +import abc +from dataclasses import dataclass +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple + +from flax import linen as nn +import jax.numpy as jnp +import numpy as np +import optax + + +# These data classes are pure Python and can be used directly. 
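+# As a quick illustration (task names here are only examples), a two-task
+# setup might be declared as:
+#
+#   task_configs = [
+#       TaskConfig('is_click', 1, MultitaskTaskType.BINARY_CLASSIFICATION),
+#       TaskConfig('watch_time', 1, MultitaskTaskType.REGRESSION),
+#   ]
+#
+# DefaultMultitaskModule below requires this list to be sorted by task_type,
+# i.e. all classification tasks before all regression tasks.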
+class MultitaskTaskType(IntEnum):
+  BINARY_CLASSIFICATION = 0
+  REGRESSION = 1
+
+
+@dataclass
+class TaskConfig:
+  task_name: str
+  task_weight: int
+  task_type: MultitaskTaskType
+
+
+class MultitaskModule(nn.Module):
+  """Abstract base class for multitask modules in Flax."""
+
+  def __call__(
+      self,
+      encoded_user_embeddings: jnp.ndarray,
+      item_embeddings: jnp.ndarray,
+      supervision_labels: Dict[str, jnp.ndarray],
+      supervision_weights: Dict[str, jnp.ndarray],
+      deterministic: bool,
+  ) -> Tuple[
+      jnp.ndarray,
+      Optional[jnp.ndarray],
+      Optional[jnp.ndarray],
+      Optional[jnp.ndarray],
+  ]:
+    """Computes multi-task predictions.
+
+    Args:
+      encoded_user_embeddings: (B, N, D) float array.
+      item_embeddings: (B, N, D) float array.
+      supervision_labels: Dictionary of (B, N) float or int arrays.
+      supervision_weights: Dictionary of (B, N) float or int arrays.
+      deterministic: If True, losses are not computed (inference mode).
+
+    Returns:
+      A tuple of (predictions, labels, weights, losses).
+      Predictions are of shape (num_tasks, B, N).
+    """
+    raise NotImplementedError
+
+
+def _compute_pred_and_logits(
+    prediction_module: nn.Module,
+    encoded_user_embeddings: jnp.ndarray,
+    item_embeddings: jnp.ndarray,
+    task_offsets: List[int],
+    has_multiple_task_types: bool,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Computes predictions and raw logits from user and item embeddings."""
+  # Logits are computed by applying the prediction module to the
+  # element-wise product.
+  # Input shape: (B, N, D), Output shape: (B, N, num_tasks)
+  mt_logits_untransposed = prediction_module(
+      encoded_user_embeddings * item_embeddings
+  )
+  # Transpose to (num_tasks, B, N) to match the PyTorch logic.
+  mt_logits = jnp.transpose(mt_logits_untransposed, (2, 0, 1))
+
+  mt_preds_list: List[jnp.ndarray] = []
+  for task_type in MultitaskTaskType:
+    start_offset, end_offset = (
+        task_offsets[task_type],
+        task_offsets[task_type + 1],
+    )
+    if end_offset > start_offset:
+      task_logits = mt_logits[start_offset:end_offset, ...]
+      if task_type == MultitaskTaskType.REGRESSION:
+        # For regression, predictions are the raw logits.
+        mt_preds_list.append(task_logits)
+      else:
+        # For classification, predictions are the sigmoid of the logits.
+        mt_preds_list.append(nn.sigmoid(task_logits))
+
+  mt_preds = (
+      jnp.concatenate(mt_preds_list, axis=0)
+      if has_multiple_task_types
+      else mt_preds_list[0]
+  )
+
+  return mt_preds, mt_logits
+
+
+def _compute_labels_and_weights(
+    supervision_labels: Dict[str, jnp.ndarray],
+    supervision_weights: Dict[str, jnp.ndarray],
+    task_configs: List[TaskConfig],
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+  """Aggregates label and weight tensors from input dictionaries."""
+  # Get a sample tensor to determine shape and dtype for the default weight.
+  first_label = next(iter(supervision_labels.values()))
+  default_supervision_weight = jnp.ones_like(first_label)
+
+  mt_labels_list: List[jnp.ndarray] = []
+  mt_weights_list: List[jnp.ndarray] = []
+  for task in task_configs:
+    mt_labels_list.append(supervision_labels[task.task_name])
+    mt_weights_list.append(
+        supervision_weights.get(task.task_name, default_supervision_weight)
+    )
+
+  # Stack along a new 'task' dimension.
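+  # E.g. stacking num_tasks arrays of shape (B, N) yields (num_tasks, B, N),
+  # matching the layout of the logits.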
+ mt_labels = jnp.stack(mt_labels_list, axis=0) + mt_weights = jnp.stack(mt_weights_list, axis=0) + + return mt_labels, mt_weights + + +def _compute_loss( + task_offsets: List[int], + causal_multitask_weights: float, + mt_logits: jnp.ndarray, + mt_labels: jnp.ndarray, + mt_weights: jnp.ndarray, + has_multiple_task_types: bool, +) -> jnp.ndarray: + """Computes the final loss across all tasks.""" + mt_losses_list: List[jnp.ndarray] = [] + for task_type in MultitaskTaskType: + start_offset, end_offset = ( + task_offsets[task_type], + task_offsets[task_type + 1], + ) + if end_offset > start_offset: + task_logits = mt_logits[start_offset:end_offset, ...] + task_labels = mt_labels[start_offset:end_offset, ...] + task_weights = mt_weights[start_offset:end_offset, ...] + + if task_type == MultitaskTaskType.REGRESSION: + # Equivalent to mse_loss with reduction='none'. + task_losses = (task_logits - task_labels) ** 2 + else: + # Equivalent to binary_cross_entropy_with_logits with reduction='none'. + task_losses = optax.sigmoid_binary_cross_entropy( + task_logits, task_labels + ) + + # Apply task-specific weights. + mt_losses_list.append(task_losses * task_weights) + + mt_losses = ( + jnp.concatenate(mt_losses_list, axis=0) + if has_multiple_task_types + else mt_losses_list[0] + ) + + # Normalize loss per task by the sum of weights for that task. + # Sum over the item dimension (axis=-1). + sum_losses = mt_losses.sum(axis=-1) + sum_weights = mt_weights.sum(axis=-1) + + # Clamp sum_weights to avoid division by zero for empty examples. + normalized_losses = sum_losses / jnp.maximum(sum_weights, 1.0) + + # Apply a global weight for this entire multitask head. + return normalized_losses * causal_multitask_weights + + +class DefaultMultitaskModule(MultitaskModule): + """ + JAX/Flax implementation of the default multitask module. + + Attributes: + task_configs: A list of TaskConfig objects, which must be pre-sorted + by task_type. + embedding_dim: The dimensionality of the input embeddings. + prediction_fn: A function that returns a Flax module for predictions, + e.g., a simple MLP. It takes embedding_dim and num_tasks as input. + causal_multitask_weights: A global weight for the final computed loss. + """ + task_configs: List[TaskConfig] + embedding_dim: int + prediction_fn: Callable[[int, int], nn.Module] + causal_multitask_weights: float + + def setup(self): + if not self.task_configs: + raise ValueError("task_configs must be non-empty.") + + # Check if tasks are sorted by type, as required by the original logic. + is_sorted = all( + self.task_configs[i].task_type <= self.task_configs[i + 1].task_type + for i in range(len(self.task_configs) - 1) + ) + if not is_sorted: + raise ValueError("task_configs must be sorted by task_type.") + + # Calculate offsets for slicing tensors based on task type. + task_offsets_list = [0] * (len(MultitaskTaskType) + 1) + for task in self.task_configs: + task_offsets_list[task.task_type + 1] += 1 + + self._has_multiple_task_types: bool = ( + task_offsets_list.count(0) < len(MultitaskTaskType) + ) + self._task_offsets: List[int] = np.cumsum(task_offsets_list).tolist() + + # Instantiate the prediction module. 
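+    # prediction_fn maps (embedding_dim, num_tasks) to a module whose output
+    # feature dimension is num_tasks, e.g. (illustrative):
+    #   lambda dim, num_tasks: nn.Dense(features=num_tasks)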
+ self._prediction_module = self.prediction_fn( + self.embedding_dim, len(self.task_configs) + ) + + def __call__( + self, + encoded_user_embeddings: jnp.ndarray, + item_embeddings: jnp.ndarray, + supervision_labels: Dict[str, jnp.ndarray], + supervision_weights: Dict[str, jnp.ndarray], + deterministic: bool, + ) -> Tuple[ + jnp.ndarray, + Optional[jnp.ndarray], + Optional[jnp.ndarray], + Optional[jnp.ndarray], + ]: + + mt_preds, mt_logits = _compute_pred_and_logits( + prediction_module=self._prediction_module, + encoded_user_embeddings=encoded_user_embeddings, + item_embeddings=item_embeddings, + task_offsets=self._task_offsets, + has_multiple_task_types=self._has_multiple_task_types, + ) + + mt_labels: Optional[jnp.ndarray] = None + mt_weights: Optional[jnp.ndarray] = None + mt_losses: Optional[jnp.ndarray] = None + + if not deterministic: + mt_labels, mt_weights = _compute_labels_and_weights( + supervision_labels=supervision_labels, + supervision_weights=supervision_weights, + task_configs=self.task_configs, + ) + mt_losses = _compute_loss( + task_offsets=self._task_offsets, + causal_multitask_weights=self.causal_multitask_weights, + mt_logits=mt_logits, + mt_labels=mt_labels, + mt_weights=mt_weights, + has_multiple_task_types=self._has_multiple_task_types, + ) + + return mt_preds, mt_labels, mt_weights, mt_losses diff --git a/recml/examples/DLRM_HSTU/positional_encoder.py b/recml/examples/DLRM_HSTU/positional_encoder.py new file mode 100644 index 0000000..34c8f99 --- /dev/null +++ b/recml/examples/DLRM_HSTU/positional_encoder.py @@ -0,0 +1,242 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAX implementation of positional and timestamp encoding for sequences.""" + +from math import sqrt +from typing import Optional + +from flax import linen as nn +import jax +import jax.numpy as jnp + + +def _get_col_indices( + max_seq_len: int, + max_contextual_seq_len: int, + max_pos_ind: int, + seq_lengths: jnp.ndarray, + num_targets: Optional[jnp.ndarray], + interleave_targets: bool, +) -> jnp.ndarray: + """Calculates the positional indices for each element in the sequence. + + JAX translation of `_get_col_indices` from pt_position.py. + + Args: + max_seq_len: The maximum sequence length. + max_contextual_seq_len: The maximum length of the contextual prefix. + max_pos_ind: The maximum positional index. + seq_lengths: A 1D tensor of shape (batch_size,) with the true length of + each sequence. + num_targets: An optional 1D tensor of shape (batch_size,) indicating the + number of target items at the end of each sequence. + interleave_targets: A boolean indicating whether to interleave targets. + + Returns: + A 2D tensor of shape (batch_size, max_seq_len) containing the positional + indices for each element in the sequence. 
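+
+  Example (illustrative): with max_seq_len=5, seq_lengths=[3],
+  num_targets=None, max_contextual_seq_len=0 and a large max_pos_ind, the
+  result is [[3, 2, 1, 0, 0]]; indices decrease with recency, and
+  out-of-range slots are clipped to 0.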
+  """
+  batch_size = seq_lengths.shape[0]
+  col_indices = jnp.tile(
+      jnp.arange(max_seq_len, dtype=jnp.int32), (batch_size, 1)
+  )
+
+  if num_targets is not None:
+    if interleave_targets:
+      high_inds = seq_lengths - num_targets * 2
+    else:
+      high_inds = seq_lengths - num_targets
+
+    col_indices = jnp.minimum(col_indices, high_inds[:, jnp.newaxis])
+    col_indices = high_inds[:, jnp.newaxis] - col_indices
+  else:
+    col_indices = seq_lengths[:, jnp.newaxis] - col_indices
+
+  col_indices = col_indices + max_contextual_seq_len
+  col_indices = jnp.clip(col_indices, 0, max_pos_ind - 1)
+
+  if max_contextual_seq_len > 0:
+    contextual_indices = jnp.arange(max_contextual_seq_len, dtype=jnp.int32)[
+        jnp.newaxis, :
+    ]
+    col_indices = col_indices.at[:, :max_contextual_seq_len].set(
+        contextual_indices
+    )
+
+  return col_indices
+
+
+def add_timestamp_positional_embeddings(
+    seq_embeddings: jnp.ndarray,
+    pos_embeddings: jnp.ndarray,
+    ts_embeddings: jnp.ndarray,
+    timestamps: jnp.ndarray,
+    max_seq_len: int,
+    max_contextual_seq_len: int,
+    seq_lengths: jnp.ndarray,
+    num_targets: Optional[jnp.ndarray],
+    interleave_targets: bool,
+    time_bucket_fn: str,
+) -> jnp.ndarray:
+  """Adds timestamp and positional embeddings to sequence embeddings.
+
+  JAX translation of `pytorch_add_timestamp_positional_embeddings`. Assumes
+  inputs are padded dense tensors.
+
+  Args:
+    seq_embeddings: A 3D padded tensor of shape (batch_size, max_seq_len,
+      embedding_dim) containing the input item embeddings.
+    pos_embeddings: The learned positional embedding weights.
+    ts_embeddings: The learned timestamp embedding weights.
+    timestamps: A 2D padded tensor of shape (batch_size, max_seq_len)
+      containing timestamps for each item.
+    max_seq_len: The maximum sequence length for padding.
+    max_contextual_seq_len: The maximum length of the contextual prefix.
+    seq_lengths: A 1D tensor of shape (batch_size,) with the true length of
+      each sequence.
+    num_targets: An optional 1D tensor of shape (batch_size,) indicating the
+      number of target items at the end of each sequence.
+    interleave_targets: A boolean indicating whether to interleave targets.
+    time_bucket_fn: The function to use for time bucketing ("log" or "sqrt").
+
+  Returns:
+    A 3D tensor of the same shape as `seq_embeddings` with positional
+    and time embeddings added.
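+
+  Example (a minimal sketch; all sizes are illustrative)::
+
+    B, N, D = 2, 8, 16
+    out = add_timestamp_positional_embeddings(
+        seq_embeddings=jnp.zeros((B, N, D)),
+        pos_embeddings=jnp.zeros((32, D)),  # 32 position buckets
+        ts_embeddings=jnp.zeros((64 + 1, D)),  # num_time_buckets + 1 rows
+        timestamps=jnp.arange(B * N).reshape(B, N),
+        max_seq_len=N,
+        max_contextual_seq_len=0,
+        seq_lengths=jnp.array([8, 5]),
+        num_targets=None,
+        interleave_targets=False,
+        time_bucket_fn="sqrt",
+    )  # out.shape == (B, N, D)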
+  """
+  # Position encoding
+  max_pos_ind = pos_embeddings.shape[0]
+  pos_inds = _get_col_indices(
+      max_seq_len=max_seq_len,
+      max_contextual_seq_len=max_contextual_seq_len,
+      max_pos_ind=max_pos_ind,
+      seq_lengths=seq_lengths,
+      num_targets=num_targets,
+      interleave_targets=interleave_targets,
+  )
+  position_embeddings = pos_embeddings[pos_inds]
+
+  # Timestamp encoding
+  batch_size = seq_lengths.shape[0]
+  num_time_buckets = ts_embeddings.shape[0] - 1
+  # Bucketing constants mirroring the reference implementation's defaults.
+  time_bucket_increments = 60.0
+  time_bucket_divisor = 1.0
+  time_delta = 0
+
+  # Get the last valid timestamp from each padded sequence for query_time
+  query_indices = jnp.maximum(0, seq_lengths - 1)
+  query_time = timestamps[jnp.arange(batch_size), query_indices][:, jnp.newaxis]
+
+  ts = query_time - timestamps
+  ts = ts + time_delta
+  ts = jnp.maximum(ts, 1e-6) / time_bucket_increments
+
+  if time_bucket_fn == "log":
+    ts = jnp.log(ts)
+  elif time_bucket_fn == "sqrt":
+    ts = jnp.sqrt(ts)
+  else:
+    raise ValueError(f"Unsupported time_bucket_fn: {time_bucket_fn}")
+
+  ts = (ts / time_bucket_divisor).clip(min=0).astype(jnp.int32)
+  ts = jnp.clip(ts, 0, num_time_buckets)
+
+  time_embeddings = ts_embeddings[ts]
+
+  # Combine embeddings
+  added_embeddings = (position_embeddings + time_embeddings).astype(
+      seq_embeddings.dtype
+  )
+
+  # The original op implies addition to only the valid (non-padded) parts.
+  # In a dense representation, this is equivalent to masking the added
+  # embeddings.
+  mask = (
+      jnp.arange(max_seq_len, dtype=jnp.int32)[jnp.newaxis, :]
+      < seq_lengths[:, jnp.newaxis]
+  )
+  masked_added_embeddings = added_embeddings * mask[..., jnp.newaxis]
+
+  return seq_embeddings + masked_added_embeddings
+
+
+class HSTUPositionalEncoder(nn.Module):
+  """JAX implementation of HSTUPositionalEncoder.
+
+  This module computes and adds positional and timestamp-based embeddings
+  to a sequence of input embeddings.
+
+  Attributes:
+    num_position_buckets: The total number of position buckets.
+    num_time_buckets: The total number of time buckets.
+    embedding_dim: The dimensionality of the embeddings.
+    contextual_seq_len: The length of the contextual prefix in sequences.
+  """
+
+  num_position_buckets: int
+  num_time_buckets: int
+  embedding_dim: int
+  contextual_seq_len: int
+
+  @nn.compact
+  def __call__(
+      self,
+      max_seq_len: int,
+      seq_lengths: jnp.ndarray,
+      seq_timestamps: jnp.ndarray,
+      seq_embeddings: jnp.ndarray,
+      num_targets: Optional[jnp.ndarray],
+  ) -> jnp.ndarray:
+    """Adds positional and timestamp embeddings to the input sequence embeddings.
+
+    Args:
+      max_seq_len: The maximum sequence length for padding.
+      seq_lengths: A 1D tensor of shape (batch_size,) with the true length of
+        each sequence.
+      seq_timestamps: A 2D padded tensor of shape (batch_size, max_seq_len)
+        containing timestamps for each item.
+      seq_embeddings: A 3D padded tensor of shape (batch_size, max_seq_len,
+        embedding_dim) containing the input item embeddings.
+      num_targets: An optional 1D tensor of shape (batch_size,) indicating the
+        number of target items at the end of each sequence.
+
+    Returns:
+      A 3D tensor of the same shape as `seq_embeddings` with positional
+      and time embeddings added.
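+
+    Example (sizes are illustrative)::
+
+      encoder = HSTUPositionalEncoder(
+          num_position_buckets=32,
+          num_time_buckets=64,
+          embedding_dim=16,
+          contextual_seq_len=0,
+      )
+      out, variables = encoder.init_with_output(
+          jax.random.PRNGKey(0),
+          max_seq_len=8,
+          seq_lengths=jnp.array([8, 5]),
+          seq_timestamps=jnp.ones((2, 8), dtype=jnp.int32),
+          seq_embeddings=jnp.zeros((2, 8, 16)),
+          num_targets=None,
+      )  # out.shape == (2, 8, 16)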
+ """ + position_embeddings_weight = self.param( + "_position_embeddings_weight", + nn.initializers.uniform(scale=sqrt(1.0 / self.num_position_buckets)), + (self.num_position_buckets, self.embedding_dim), + ) + timestamp_embeddings_weight = self.param( + "_timestamp_embeddings_weight", + nn.initializers.uniform(scale=sqrt(1.0 / self.num_time_buckets)), + (self.num_time_buckets + 1, self.embedding_dim), + ) + + scaled_seq_embeddings = seq_embeddings * sqrt(self.embedding_dim) + + final_embeddings = add_timestamp_positional_embeddings( + seq_embeddings=scaled_seq_embeddings, + pos_embeddings=position_embeddings_weight, + ts_embeddings=timestamp_embeddings_weight, + timestamps=seq_timestamps, + max_seq_len=max_seq_len, + max_contextual_seq_len=self.contextual_seq_len, + seq_lengths=seq_lengths, + num_targets=num_targets, + interleave_targets=False, + time_bucket_fn="sqrt", + ) + return final_embeddings diff --git a/recml/examples/DLRM_HSTU/postprocessors.py b/recml/examples/DLRM_HSTU/postprocessors.py new file mode 100644 index 0000000..59696e0 --- /dev/null +++ b/recml/examples/DLRM_HSTU/postprocessors.py @@ -0,0 +1,171 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Postprocessors for user embeddings after HSTU layers.""" + +import math +from typing import Any, Dict, List, Tuple + +import flax.linen as nn +from flax.linen.initializers import xavier_uniform +from flax.linen.initializers import zeros +import jax.numpy as jnp + + +Array = jnp.ndarray +Dtype = Any + + +class OutputPostprocessor(nn.Module): + """An abstract class for post-processing user embeddings after HSTU layers.""" + + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + """Processes the final sequence embeddings. + + Args: + seq_embeddings: (B, N, D) or (L, D) final embeddings from the model. + seq_timestamps: (B, N) or (L,) corresponding timestamps. + seq_payloads: A dictionary of other features. + + Returns: + The post-processed sequence embeddings. 
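+
+    Example (illustrative), using the L2-normalization subclass defined
+    below; it holds no parameters, so an empty variables mapping suffices::
+
+      pp = L2NormPostprocessor()
+      out = pp.apply({}, seq_embeddings, seq_timestamps, seq_payloads={})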
+ """ + raise NotImplementedError + + +class L2NormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with L2 normalization.""" + epsilon: float = 1e-6 + + @nn.compact + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + norm = jnp.linalg.norm(seq_embeddings, ord=2, axis=-1, keepdims=True) + # Prevent division by zero + safe_norm = jnp.maximum(norm, self.epsilon) + return seq_embeddings / safe_norm + + +class LayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with LayerNorm.""" + embedding_dim: int + eps: float = 1e-5 + dtype: Dtype = jnp.float32 + + @nn.compact + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + ln = nn.LayerNorm(epsilon=self.eps, dtype=self.dtype) + return ln(seq_embeddings) + + +class TimestampLayerNormPostprocessor(OutputPostprocessor): + """Postprocesses user embeddings with a timestamp-based MLP and LayerNorm.""" + embedding_dim: int + time_duration_features: List[Tuple[int, int]] + eps: float = 1e-5 + dtype: Dtype = jnp.float32 + + def setup(self): + self._layer_norm = nn.LayerNorm(epsilon=self.eps, dtype=self.dtype) + + num_time_features = len(self.time_duration_features) + combiner_input_dim = self.embedding_dim + 2 * num_time_features + + self._time_feature_combiner = nn.Dense( + features=self.embedding_dim, + dtype=self.dtype, + kernel_init=xavier_uniform(), + bias_init=zeros, + ) + + # Store time feature constants directly. No need for buffers in Flax. + self._period_units = jnp.array( + [f[0] for f in self.time_duration_features], dtype=self.dtype + ) + self._units_per_period = jnp.array( + [f[1] for f in self.time_duration_features], dtype=self.dtype + ) + + def __call__( + self, + seq_embeddings: Array, + seq_timestamps: Array, + seq_payloads: Dict[str, Array], + ) -> Array: + """Processes sequence embeddings with timestamp features and LayerNorm. + + Creates circular time features, concatenates them to the embeddings, + processes through an MLP, and applies LayerNorm. + + Args: + seq_embeddings: (B, N, D) or (L, D) final embeddings from the model. + seq_timestamps: (B, N) or (L,) corresponding timestamps. + seq_payloads: A dictionary of other features. + + Returns: + The post-processed sequence embeddings. + """ + + # 1. Create circular time features from timestamps. + # Ensure timestamps have a feature dimension for broadcasting. + if seq_timestamps.ndim != seq_embeddings.ndim: + timestamps = jnp.expand_dims(seq_timestamps, axis=-1) + else: + timestamps = seq_timestamps + + # Ensure correct broadcast shape for time constants. + # Original shape: (num_features,) -> (1, ..., 1, num_features) + broadcast_shape = (1,) * (timestamps.ndim - 1) + (-1,) + period_units = self._period_units.reshape(broadcast_shape) + units_per_period = self._units_per_period.reshape(broadcast_shape) + + # Calculate the phase angle for the circular representation. + units_since_epoch = jnp.floor(timestamps / period_units) + remainder = jnp.remainder(units_since_epoch, units_per_period) + angle = (remainder / units_per_period) * 2 * math.pi + + # Create sin/cos features. Cast to float32 for precision if needed. 
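+    # (Trigonometric functions of large phase angles lose precision in
+    # reduced-precision dtypes such as bfloat16, hence the temporary upcast.)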
+    original_dtype = angle.dtype
+    if original_dtype != jnp.float32:
+      angle = angle.astype(jnp.float32)
+
+    cos_features = jnp.cos(angle)
+    sin_features = jnp.sin(angle)
+
+    time_features = jnp.stack([cos_features, sin_features], axis=-1)
+
+    # New shape will have a final dimension of num_time_features * 2
+    final_shape = seq_embeddings.shape[:-1] + (-1,)
+    time_features = time_features.reshape(final_shape).astype(original_dtype)
+    # 2. Concatenate with sequence embeddings.
+    combined_embeddings = jnp.concatenate(
+        [seq_embeddings, time_features], axis=-1
+    )
+    # 3. Process through the MLP and LayerNorm.
+    user_embeddings = self._time_feature_combiner(combined_embeddings)
+    final_embeddings = self._layer_norm(user_embeddings)
+    return final_embeddings
diff --git a/recml/examples/DLRM_HSTU/preprocessors.py b/recml/examples/DLRM_HSTU/preprocessors.py
new file mode 100644
index 0000000..17e5f15
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/preprocessors.py
@@ -0,0 +1,131 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Input preprocessors for HSTU models."""
+
+from typing import Any, Dict, Tuple
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+Array = jnp.ndarray
+Dtype = Any
+
+
+class SwishLayerNorm(nn.Module):
+  """JAX/Flax implementation of SwishLayerNorm.
+
+  Corresponds to generative_recommenders/ops/layer_norm.py -> SwishLayerNorm.
+  The PyTorch implementation is: x * sigmoid(layer_norm(x)).
+  """
+
+  epsilon: float = 1e-5
+  dtype: Dtype = jnp.float32
+
+  @nn.compact
+  def __call__(self, x: Array) -> Array:
+    """Applies swish layer normalization to the input."""
+    ln = nn.LayerNorm(
+        epsilon=self.epsilon,
+        use_bias=True,
+        use_scale=True,
+        dtype=self.dtype,
+    )
+    normed_x = ln(x)
+    return x * nn.sigmoid(normed_x)
+
+
+class InputPreprocessor(nn.Module):
+  """An abstract class for pre-processing sequence embeddings before HSTU layers."""
+
+  def __call__(
+      self,
+      max_uih_len: int,
+      seq_embeddings: Array,
+      seq_mask: Array,
+      seq_timestamps: Array,
+      num_targets: Array,
+      seq_payloads: Dict[str, Array],
+      *,
+      deterministic: bool,
+  ) -> Tuple[Array, Array, Array, Array, Dict[str, Array]]:
+    """Processes input sequences and their features.
+
+    Args:
+      max_uih_len: Maximum length of the user item history.
+      seq_embeddings: (B, N, D) Padded sequence embeddings.
+      seq_mask: (B, N) Boolean mask for seq_embeddings.
+      seq_timestamps: (B, N) Padded timestamps.
+      num_targets: (B,) Number of targets for each sequence.
+      seq_payloads: Dict of other features, also as padded tensors with
+        masks.
+      deterministic: Controls dropout behavior.
+
+    Returns:
+      A tuple containing the processed (
+        output_embeddings,
+        output_mask,
+        output_timestamps,
+        output_num_targets,
+        output_payloads
+      ).
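+
+      Implementations may change the sequence length: e.g. a preprocessor
+      that interleaves one action embedding after every history item roughly
+      doubles N, and all returned tensors follow the new layout.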
+    """
+    raise NotImplementedError
+
+  def interleave_targets(self) -> bool:
+    return False
+
+
+def get_contextual_input_embeddings(
+    seq_mask: Array,
+    seq_payloads: Dict[str, Array],
+    contextual_feature_to_max_length: Dict[str, int],
+    contextual_feature_to_min_uih_length: Dict[str, int],
+    dtype: Dtype,
+) -> Array:
+  """Constructs the input for contextual embeddings from dense tensors.
+
+  Args:
+    seq_mask: Boolean mask for the sequence.
+    seq_payloads: Dictionary of all feature tensors.
+    contextual_feature_to_max_length: Maps feature names to their max length.
+    contextual_feature_to_min_uih_length: Maps features to a min uih length
+      for them to be active.
+    dtype: Data type for the output.
+
+  Returns:
+    A dense tensor of shape (batch_size, sum_of_dims).
+  """
+  padded_values = []
+  seq_lengths = jnp.sum(seq_mask, axis=1, dtype=jnp.int32)
+
+  for key, max_len in contextual_feature_to_max_length.items():
+    # Assuming the payload is already a dense tensor of shape (B, L, D)
+    v = seq_payloads[key].astype(dtype)
+
+    min_uih_length = contextual_feature_to_min_uih_length.get(key, 0)
+    if min_uih_length > 0:
+      # Create a mask to zero out embeddings for sequences that are too short
+      mask = (seq_lengths >= min_uih_length).reshape(-1, 1, 1)
+      v *= mask
+
+    # Flatten the feature dimension
+    padded_values.append(v.reshape(v.shape[0], -1))
+
+  return jnp.concatenate(padded_values, axis=1)
diff --git a/recml/examples/DLRM_HSTU/stu.py b/recml/examples/DLRM_HSTU/stu.py
new file mode 100644
index 0000000..82ec7a0
--- /dev/null
+++ b/recml/examples/DLRM_HSTU/stu.py
@@ -0,0 +1,357 @@
+# Copyright 2024 RecML authors .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Sequential Transduction Unit (STU) module.
+
+This module implements the STU layer and a stack of STU layers. The STU
+layer captures long-range dependencies in sequential data by combining
+self-attention with a gating mechanism.
+"""
+
+from dataclasses import dataclass
+from typing import Optional, Sequence
+
+from flax import linen as nn
+import jax
+import jax.numpy as jnp
+
+
+@dataclass
+class STULayerConfig:
+  """Configuration for the STU layer.
+
+  Attributes:
+    embedding_dim: Input embedding dimension.
+    num_heads: Number of attention heads.
+    hidden_dim: Hidden dimension of the STU layer.
+    attention_dim: Dimension of the attention projections.
+    output_dropout_ratio: Dropout ratio for the output.
+    causal: Whether to use causal attention.
+    target_aware: Whether to use target-aware attention.
+    max_attn_len: Maximum attention length.
+    attn_alpha: Scaling factor for attention scores.
+    use_group_norm: Whether to use group normalization.
+    recompute_normed_x: Whether to recompute normalized input.
+    recompute_uvqk: Whether to recompute u, v, q, k projections.
+    recompute_y: Whether to recompute the output.
+    sort_by_length: Whether to sort sequences by length.
+    contextual_seq_len: Contextual sequence length.
+    min_full_attn_seq_len: Minimum sequence length to apply full attention.
+ norm_epsilon: Epsilon value for normalization. + deterministic: Whether to apply dropout in deterministic mode. + max_decode_length: The maximum length for decoding. + """ + + embedding_dim: int + num_heads: int + hidden_dim: int + attention_dim: int + output_dropout_ratio: float = 0.0 + causal: bool = True + target_aware: bool = True + max_attn_len: Optional[int] = None + attn_alpha: Optional[float] = None + use_group_norm: bool = False + recompute_normed_x: bool = True + recompute_uvqk: bool = True + recompute_y: bool = True + sort_by_length: bool = True + contextual_seq_len: int = 0 + norm_epsilon: float = 1e-6 + min_full_attn_seq_len: int = 0 + deterministic: bool = True + max_decode_length: int = 2048 + + +class STULayer(nn.Module): + """Self-Targeting Unit layer. + + Attributes: + config: STULayerConfig, configuration of the STU layer. + """ + + config: STULayerConfig + + def setup(self): + self.num_heads: int = self.config.num_heads + self.embedding_dim: int = self.config.embedding_dim + self.hidden_dim: int = self.config.hidden_dim + self.attention_dim: int = self.config.attention_dim + self.output_dropout_ratio: float = self.config.output_dropout_ratio + self.target_aware: bool = self.config.target_aware + self.causal: bool = self.config.causal + self.max_attn_len: int = self.config.max_attn_len or 0 + self.attn_alpha: float = self.config.attn_alpha or 1.0 / ( + self.attention_dim**0.5 + ) + self.use_group_norm: bool = self.config.use_group_norm + self.norm_epsilon: float = self.config.norm_epsilon + self.contextual_seq_len: int = self.config.contextual_seq_len + self.min_full_attn_seq_len: int = self.config.min_full_attn_seq_len + + self.uvqk_weight = self.param( + '_uvqk_weight', + nn.initializers.xavier_normal(), + ( + self.embedding_dim, + (self.hidden_dim * 2 + self.attention_dim * 2) * self.num_heads, + ), + ) + self.uvqk_beta = self.param( + '_uvqk_beta', + nn.initializers.zeros, + (self.hidden_dim * 2 + self.attention_dim * 2) * self.num_heads, + ) + + self.output_weight = self.param( + '_output_weight', + nn.initializers.xavier_uniform(), + (self.hidden_dim * self.num_heads * 3, self.embedding_dim), + ) + + self.dropout_layer = nn.Dropout(rate=self.output_dropout_ratio) + self.group_norm_layer = nn.GroupNorm( + num_groups=self.num_heads, + use_scale=True, + use_bias=True, + epsilon=self.norm_epsilon, + ) + self.input_norm_layer = nn.LayerNorm( + use_scale=True, use_bias=True, epsilon=self.norm_epsilon + ) + self.output_norm_layer = nn.LayerNorm( + use_scale=True, use_bias=True, epsilon=self.norm_epsilon + ) + self.cached_key = self.variable('cache', 'cached_key', lambda: None) + self.cached_value = self.variable('cache', 'cached_value', lambda: None) + self.cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.zeros((), jnp.int32) + ) + + def _get_valid_attn_mask(self, x, num_targets: Optional[jnp.ndarray]): + batch_size, seq_len, _ = x.shape + seq_lengths = jnp.full((batch_size,), seq_len, dtype=jnp.int32) + ids = jnp.arange(seq_len)[None, :] + max_ids = seq_lengths[:, None, None] + + if self.contextual_seq_len > 0: + ids = ids - self.contextual_seq_len + 1 + ids = jnp.maximum(ids, 0) + max_ids = max_ids - self.contextual_seq_len + 1 + + if num_targets is not None: + max_ids = (max_ids - num_targets[:, None, None]).squeeze(axis=-1) + ids = jnp.minimum(ids, max_ids) + row_ids = ids[:, :, None] + col_ids = ids[:, None, :] + else: + row_ids_base = jnp.arange(seq_len)[None, :, None] + col_ids_base = jnp.arange(seq_len)[None, None, :] + row_ids = 
jnp.broadcast_to(row_ids_base, (1, seq_len, seq_len)) + col_ids = jnp.broadcast_to(col_ids_base, (1, seq_len, seq_len)) + + row_col_dist = row_ids - col_ids + valid_attn_mask = jnp.eye(seq_len, dtype=jnp.bool_)[None, :, :] + if not self.causal: + row_col_dist = jnp.abs(row_col_dist) + valid_attn_mask = jnp.logical_or(valid_attn_mask, row_col_dist > 0) + if self.max_attn_len > 0: + if self.min_full_attn_seq_len > 0: + valid_attn_mask = jnp.logical_and( + valid_attn_mask, + jnp.logical_or( + row_col_dist <= self.max_attn_len, + row_ids >= max_ids - self.min_full_attn_seq_len, + ), + ) + else: + valid_attn_mask = jnp.logical_and( + valid_attn_mask, row_col_dist <= self.max_attn_len + ) + if self.contextual_seq_len > 0: + valid_attn_mask = jnp.logical_or( + valid_attn_mask, jnp.logical_and(row_ids == 0, col_ids < max_ids) + ) + + return valid_attn_mask + + def hstu_compute_output(self, attn, u, x, deterministic: bool): + """Computes the output of the STU layer with corrected logic.""" + if self.use_group_norm: + norm_input = attn.reshape( + attn.shape[0], attn.shape[1], self.num_heads, self.hidden_dim + ) + normed_attn = self.group_norm_layer(norm_input).reshape( + attn.shape[0], attn.shape[1], -1 + ) + else: + normed_attn = self.output_norm_layer(attn) + + gated_attn = u * normed_attn + proj_input = jnp.concatenate([u, attn, gated_attn], axis=-1) + projected_output = proj_input @ self.output_weight + dropped_out = self.dropout_layer( + projected_output, deterministic=deterministic + ) + return x + dropped_out + + def hstu_preprocess_and_attention( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray], + deterministic: bool, + decode: bool = False, + ): + """Replicated STU preprocess and attention.""" + normed_x = self.input_norm_layer(x) + uvqk = normed_x @ self.uvqk_weight + self.uvqk_beta + u_proj, v_proj, q_proj, k_proj = jnp.split( + uvqk, + [ + self.hidden_dim * self.num_heads, + self.hidden_dim * self.num_heads * 2, + self.hidden_dim * self.num_heads * 2 + + self.attention_dim * self.num_heads, + ], + axis=-1, + ) + + u = nn.silu(u_proj) + batch_size, seq_len, _ = x.shape + q = q_proj.reshape( + batch_size, seq_len, self.num_heads, self.attention_dim + ).transpose(0, 2, 1, 3) + k = k_proj.reshape( + batch_size, seq_len, self.num_heads, self.attention_dim + ).transpose(0, 2, 1, 3) + v = v_proj.reshape( + batch_size, seq_len, self.num_heads, self.hidden_dim + ).transpose(0, 2, 1, 3) + + cache_index = 0 + if decode: + is_initialized = ( + self.has_variable('cache', 'cached_key') + and self.cached_key.value is not None + ) + cache_index = self.cache_index.value + if not is_initialized and self.is_mutable_collection('cache'): + k_cache_shape = ( + batch_size, + self.num_heads, + self.config.max_decode_length, + self.attention_dim, + ) + v_cache_shape = ( + batch_size, + self.num_heads, + self.config.max_decode_length, + self.hidden_dim, + ) + self.cached_key.value = jnp.zeros(k_cache_shape, k.dtype) + self.cached_value.value = jnp.zeros(v_cache_shape, v.dtype) + + if self.is_mutable_collection('cache'): + k_cache = jax.lax.dynamic_update_slice( + self.cached_key.value, + k.astype(self.cached_key.value.dtype), + (0, 0, cache_index, 0), + ) + v_cache = jax.lax.dynamic_update_slice( + self.cached_value.value, + v.astype(self.cached_value.value.dtype), + (0, 0, cache_index, 0), + ) + self.cached_key.value = k_cache + self.cached_value.value = v_cache + self.cache_index.value = cache_index + seq_len + k = k_cache + v = v_cache + elif is_initialized: + k = self.cached_key.value + v = 
self.cached_value.value + else: + raise ValueError('Cache not initialized and not mutable.') + + attn_scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) * self.attn_alpha + if decode: + attn_weights = nn.silu(attn_scores) / self.config.max_decode_length + mask = (jnp.arange(self.config.max_decode_length) <= cache_index)[ + None, None, None, : + ] + attn_weights = attn_weights * mask + else: + attn_weights = nn.silu(attn_scores) / seq_len + mask = self._get_valid_attn_mask(x, num_targets) + attn_weights = attn_weights * mask[:, None, :, :] + + attn_weights = self.dropout_layer(attn_weights, deterministic=deterministic) + attn_output = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v) + attn_output = attn_output.transpose(0, 2, 1, 3).reshape( + batch_size, seq_len, -1 + ) + + return u, attn_output, k_proj, v_proj + + def __call__( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray] = None, + deterministic: bool = True, + decode: bool = False, + ): + """Computes the STU layer.""" + actual_num_targets = num_targets if self.target_aware else None + u, attn_output, _, _ = self.hstu_preprocess_and_attention( + x, + actual_num_targets, + deterministic=deterministic, + decode=decode, + ) + final_output = self.hstu_compute_output( + attn=attn_output, u=u, x=x, deterministic=deterministic + ) + return final_output + + +class STUStack(nn.Module): + """STU stack. + + This module creates a stack of STU layers. + + Attributes: + configs: A sequence of STU layer configs. + """ + + configs: Sequence[STULayerConfig] + + def setup(self): + self.stu_layers = [ + STULayer(config=c, name=f'stu_layer_{i}') + for i, c in enumerate(self.configs) + ] + + def __call__( + self, + x: jnp.ndarray, + num_targets: Optional[jnp.ndarray] = None, + deterministic: bool = True, + decode: bool = False, + ): + for i in range(len(self.stu_layers)): + x = self.stu_layers[i](x, num_targets, deterministic, decode) + return x diff --git a/recml/examples/DLRM_HSTU/stu_test.py b/recml/examples/DLRM_HSTU/stu_test.py new file mode 100644 index 0000000..4df8dac --- /dev/null +++ b/recml/examples/DLRM_HSTU/stu_test.py @@ -0,0 +1,349 @@ +# Copyright 2024 RecML authors . +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
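+"""Tests for the JAX STU (Sequential Transduction Unit) implementation."""
+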
+from absl import logging +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from jax.sharding import NamedSharding +from jax.sharding import PartitionSpec +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from flax import core +from recml.examples.DLRM_HSTU.stu import STULayer +from recml.examples.DLRM_HSTU.stu import STULayerConfig +from recml.examples.DLRM_HSTU.stu import STUStack + + +def get_test_configs(): + """Generates a list of test configurations.""" + test_params = [] + test_params.append(( + "basic_config", + { + "num_layers": 2, + "num_heads": 2, + "batch_size": 4, + "max_len": 32, + "embedding_dim": 16, + "attention_dim": 8, + "hidden_dim": 24, + "use_group_norm": False, + "target_aware": True, + }, + )) + test_params.append(( + "group_norm", + { + "num_layers": 1, + "num_heads": 4, + "batch_size": 2, + "max_len": 16, + "embedding_dim": 32, + "attention_dim": 16, + "hidden_dim": 20, + "use_group_norm": True, + "target_aware": True, + }, + )) + test_params.append(( + "not_target_aware", + { + "num_layers": 1, + "num_heads": 1, + "batch_size": 8, + "max_len": 64, + "embedding_dim": 8, + "attention_dim": 4, + "hidden_dim": 12, + "use_group_norm": False, + "target_aware": False, + }, + )) + test_params.append(( + "sliding_window_attention", + { + "num_layers": 1, + "num_heads": 2, + "batch_size": 2, + "max_len": 20, + "embedding_dim": 16, + "attention_dim": 8, + "hidden_dim": 16, + "use_group_norm": False, + "target_aware": True, + "max_attn_len": 5, + }, + )) + return test_params + + +class StuJaxTest(parameterized.TestCase): + """Unit tests for the JAX STU implementation.""" + + @classmethod + def setUpClass(cls): + super().setUpClass() + logging.info("Available devices: %s", jax.devices()) + # Assert that TPUs are available + assert any( + d.platform == "tpu" for d in jax.devices() + ), "No TPU devices found." + + def setUp(self): + """Set up a base key for all tests.""" + super().setUp() + self.key = jax.random.PRNGKey(42) + self.devices = jax.devices() + self.num_devices = len(self.devices) + self.mesh = Mesh(np.array(self.devices), ("data",)) + logging.info("Using device mesh: %s", self.mesh) + + self.batch_sharding = NamedSharding(self.mesh, PartitionSpec("data")) + self.replicated_sharding = NamedSharding(self.mesh, PartitionSpec()) + + @parameterized.named_parameters(get_test_configs()) + def test_output_shape_and_gradients(self, config_dict): + """Tests STUStack for output shape and valid gradients. + + This test verifies that the STUStack runs, produces the correct output + shape, and that gradients can be computed without errors (e.g., NaNs). + + Args: + config_dict: A dictionary containing the configuration parameters for the + STUStack. 
+ """ + self.assertEqual(jax.devices()[0].platform, "tpu") + + config = STULayerConfig( + embedding_dim=config_dict["embedding_dim"], + num_heads=config_dict["num_heads"], + hidden_dim=config_dict["hidden_dim"], + attention_dim=config_dict["attention_dim"], + target_aware=config_dict["target_aware"], + use_group_norm=config_dict["use_group_norm"], + max_attn_len=config_dict.get("max_attn_len", 0), + ) + + stu_configs = [config for _ in range(config_dict['num_layers'])] + model = STUStack(configs=stu_configs) + + batch_size, max_len = config_dict["batch_size"], config_dict["max_len"] + + if batch_size % self.num_devices != 0: + batch_size = (batch_size // self.num_devices + 1) * self.num_devices + logging.warning("Adjusted batch size to %d for sharding", batch_size) + + init_key, data_key, dropout_key = jax.random.split(self.key, 3) + + dummy_x = jax.random.normal( + data_key, (batch_size, max_len, config.embedding_dim) + ) + dummy_num_targets = jax.random.randint( + data_key, (batch_size,), minval=1, maxval=5 + ) + + dummy_x = jax.device_put(dummy_x, self.batch_sharding) + dummy_num_targets = jax.device_put(dummy_num_targets, self.batch_sharding) + + variables = model.init( + {"params": init_key, "dropout": dropout_key}, + x=dummy_x, + num_targets=dummy_num_targets, + ) + variables = jax.device_put(variables, self.replicated_sharding) + + @jax.jit + def loss_fn(p, x, num_targets, rng_key): + y = model.apply( + {'params': p, 'cache': variables['cache']}, + x, + num_targets=num_targets, + rngs={"dropout": rng_key}, + ) + return jnp.sum(y**2) + + # Jitted apply function + apply_fn = jax.jit( + lambda v, x, num_targets: model.apply( + v, x, num_targets=num_targets + ), + out_shardings=self.batch_sharding, + ) + + output = apply_fn(variables, dummy_x, dummy_num_targets) + self.assertEqual(output.shape, dummy_x.shape) + self.assertEqual(output.sharding, self.batch_sharding) + + grads = jax.grad(loss_fn)( + variables["params"], dummy_x, dummy_num_targets, dropout_key + ) + + grad_leaves, _ = jax.tree_util.tree_flatten(grads) + self.assertNotEmpty(grad_leaves) + for g in grad_leaves: + self.assertFalse(jnp.any(jnp.isnan(g)), "Found NaNs in gradients") + self.assertFalse(jnp.all(g == 0), "Found all-zero gradients") + self.assertEqual(g.sharding, self.replicated_sharding) + + def test_target_invariance(self): + """Tests invariance of output with target section swaps. + + This test checks if swapping items within the target section of sequences + results in an equivalently swapped output. 
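+
+    With target-aware causal attention, each target can attend to the shared
+    history and to itself but not to the other targets, so per-target outputs
+    do not depend on the ordering of the target block; swapping two targets
+    should therefore swap the corresponding rows of the output.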
+ """ + self.assertEqual(jax.devices()[0].platform, "tpu") + + batch_size, max_len, embedding_dim = 4, 32, 16 + # Adjust batch size to be divisible by the number of devices + if batch_size % self.num_devices != 0: + batch_size = (batch_size // self.num_devices + 1) * self.num_devices + logging.warning("Adjusted batch size to %d for sharding", batch_size) + + config = STULayerConfig( + embedding_dim=embedding_dim, + num_heads=2, + hidden_dim=24, + attention_dim=8, + target_aware=True, + causal=True, + ) + model = STUStack(configs=[config]) + + init_key, data_key = jax.random.split(self.key) + x = jax.random.normal(data_key, (batch_size, max_len, embedding_dim)) + num_targets = jax.random.randint( + data_key, (batch_size,), minval=2, maxval=10 + ) + + # Shard inputs + x = jax.device_put(x, self.batch_sharding) + num_targets = jax.device_put(num_targets, self.batch_sharding) + + swap_from_offset = jnp.zeros((batch_size,), dtype=jnp.int32) + swap_to_offset = jnp.ones((batch_size,), dtype=jnp.int32) + swap_from_offset = jax.device_put(swap_from_offset, self.batch_sharding) + swap_to_offset = jax.device_put(swap_to_offset, self.batch_sharding) + + swap_from_idx = max_len - 1 - swap_from_offset + swap_to_idx = max_len - 1 - swap_to_offset + + variables = model.init( + {"params": init_key, "dropout": data_key}, + x=x, + num_targets=num_targets, + ) + variables = jax.device_put(variables, self.replicated_sharding) + + apply_fn = jax.jit( + lambda v, x, num_targets: model.apply( + v, + x, + num_targets=num_targets, + ), + out_shardings=self.batch_sharding, + ) + + output_original = apply_fn(variables, x, num_targets) + self.assertEqual(output_original.sharding, self.batch_sharding) + + def swap_rows(arr, idx1, idx2): + val1 = arr[idx1] + val2 = arr[idx2] + return arr.at[idx1].set(val2).at[idx2].set(val1) + + swapped_x = jax.vmap(swap_rows)(x, swap_from_idx, swap_to_idx) + self.assertEqual(swapped_x.sharding, self.batch_sharding) + output_swapped_input = apply_fn(variables, swapped_x, num_targets) + self.assertEqual(output_swapped_input.sharding, self.batch_sharding) + + output_swapped_restored = jax.vmap(swap_rows)( + output_swapped_input, swap_from_idx, swap_to_idx + ) + self.assertEqual(output_swapped_restored.sharding, self.batch_sharding) + + np.testing.assert_allclose( + output_original, output_swapped_restored, rtol=1e-2, atol=1e-2 + ) + + def test_kv_caching(self): + """Tests that decode with KV caching is equivalent to running without cache.""" + if jax.devices()[0].platform != 'tpu': + self.skipTest('Test only supported on TPUs.') + + batch_size, max_len, embedding_dim = 2, 8, 16 + config = STULayerConfig( + embedding_dim=embedding_dim, + num_heads=2, + hidden_dim=24, + attention_dim=8, + target_aware=False, + causal=True, + max_decode_length=max_len, + ) + model = STUStack(configs=[config]) + + init_key, data_key, dropout_key = jax.random.split(self.key, 3) + x = jax.random.normal(data_key, (batch_size, max_len, embedding_dim)) + num_targets = jnp.ones((batch_size,), dtype=jnp.int32) + + # Full sequence processing + variables = model.init( + {"params": init_key, "dropout": dropout_key}, + x=x, + num_targets=num_targets, + deterministic=True, + decode=False, + ) + y_ref = model.apply( + variables, + x=x, + num_targets=num_targets, + deterministic=True, + decode=False, + ) + + # Decode step-by-step + decode_variables = model.init( + {'params': init_key, 'dropout': dropout_key}, + x=x[:, 0:1, :], + num_targets=num_targets, + deterministic=True, + decode=True, + ) + + y_decoded_list = [] + 
current_vars = decode_variables + for i in range(max_len): + y_i, mutated_vars = model.apply( + current_vars, + x[:, i : i + 1, :], + num_targets=num_targets, + deterministic=True, + decode=True, + mutable=['cache'], + ) + current_vars = core.freeze( + {'params': current_vars['params'], **mutated_vars} + ) + y_decoded_list.append(y_i) + + y_decoded = jnp.concatenate(y_decoded_list, axis=1) + + np.testing.assert_allclose(y_decoded, y_ref, rtol=0.25, atol=0.25) + + +if __name__ == "__main__": + absltest.main() + diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py index a908ab8..3849425 100644 --- a/recml/layers/linen/sparsecore.py +++ b/recml/layers/linen/sparsecore.py @@ -334,7 +334,7 @@ def _to_np(x: Any) -> np.ndarray: weights[key] = np.reshape(weights[key], (-1, 1)) self._batch_number += 1 - csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( + preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input( features=features, features_weights=weights, feature_specs=self.sparsecore_config.feature_specs, @@ -345,6 +345,7 @@ def _to_np(x: Any) -> np.ndarray: allow_id_dropping=self.sparsecore_config.allow_id_dropping, batch_number=self._batch_number, ) + csr_inputs = preprocessed_inputs.sparse_dense_matmul_input processed_inputs = { k: v for k, v in inputs.items() if k not in sparse_features @@ -362,7 +363,7 @@ class SparsecoreEmbed(nn.Module): Attributes: sparsecore_config: A sparsecore config specifying how to create the tables. mesh: The mesh to use for the embedding layer. If not provided, the global - mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an + mesh set by `jax.set_mesh` will be used. If neither is set, an error will be raised. """ @@ -375,7 +376,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh: abstract_mesh = jax.sharding.get_abstract_mesh() if not abstract_mesh.shape_tuple: raise ValueError( - 'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make' + 'No abstract mesh shape was set with `jax.set_mesh`. Make' ' sure to set the mesh when calling the sparsecore module.' ) return abstract_mesh