32 changes: 21 additions & 11 deletions recml/core/data/tf_dataset_factory.py
@@ -24,6 +24,7 @@
import re
from typing import Any, Protocol

from absl import flags
from absl import logging
import jax
from recml.core.utils import types
@@ -162,12 +163,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
Defaults to False.
seed: An optional seed to use for deterministic shuffling / preprocessing.
Defaults to None.
tf_data_service_address: An optional URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
data service will be applied.
enable_tf_data_service: Whether to use the tf.data service to offload
preprocessing for this dataset. If True, the `tf_data_service_address`
flag must be set.
tf_data_service_policy: Sharding policy to use for tf.data service when it
is enabled.
tf_data_service_job_name: Job name to use for the tf.data service. If
None, a default shared job name is used.
feature_spec: A mapping of feature keys to `FixedLenFeature`,
`VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +209,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
tensorflow.
debug: An optional boolean indicating whether to debug input boundedness. If
`True`, the dataset will consist of a single batch that's cached and
infinitely repeated
infinitely repeated.
"""

cache_reading: bool = False
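
The docstring above pins down the new contract: the service endpoint now comes from the `tf_data_service_address` flag instead of a dataclass field. A minimal sketch of a caller opting in; the flag definition and field values are illustrative, and other factory fields are omitted:

from absl import flags

# Assumption: a real binary defines this flag elsewhere; it is declared
# here only to keep the sketch self-contained.
_TF_DATA_SERVICE_ADDRESS = flags.DEFINE_string(
    "tf_data_service_address",
    "grpc://tf-data-service:5050",
    "Endpoint of the tf.data service dispatcher.",
)

factory = TFDatasetFactory(
    enable_tf_data_service=True,
    # Optional; when None a default shared job name is used.
    tf_data_service_job_name="my_training_job",
    # Note: `seed` must stay None when the service is enabled.
)
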
@@ -231,7 +232,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
readahead: str | None = None
group_uris_by_dir: bool = False
seed: int | None = None
tf_data_service_address: str | None = None
enable_tf_data_service: bool = False
tf_data_service_job_name: str | None = None
tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
tf.data.experimental.service.ShardingPolicy.OFF
)
@@ -249,7 +251,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
debug: bool = False

def __post_init__(self):
if self.tf_data_service_address is not None:
if self.enable_tf_data_service:
if flags.FLAGS.tf_data_service_address is None:
raise ValueError(
"Flag `tf_data_service_address` must be set when"
" `enable_tf_data_service` is True."
)
if self.seed is not None:
raise ValueError("`seed` must be None for data service.")
if self.sharding:
@@ -533,23 +540,26 @@ def _maybe_apply_tf_data_service(
self, dataset: tf.data.Dataset
) -> tf.data.Dataset:
"""Applies the tf.data service to the dataset."""
if self.tf_data_service_address is None:
if not self.enable_tf_data_service:
return dataset

tf_data_service_address = flags.FLAGS.tf_data_service_address

per_proc_batch_size = self.sharding_info.per_process_batch_size(
self.global_batch_size
)
logging.info(
"Applying tf.data service with address %s and per replica batch"
" size %s",
self.tf_data_service_address,
tf_data_service_address,
per_proc_batch_size,
)
return dataset.apply(
tf.data.experimental.service.distribute(
processing_mode=self.tf_data_service_policy,
service=self.tf_data_service_address,
job_name=f"bs_{per_proc_batch_size}",
service=tf_data_service_address,
job_name=self.tf_data_service_job_name
or "tf_data_service_shared_job_name",
)
)

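The method above now reads the dispatcher address from the flag at apply time. For context, the underlying tf.data API it wires up can be exercised standalone; a sketch with a placeholder address and a toy dataset:

import tensorflow as tf

dataset = tf.data.Dataset.range(1024).batch(8)
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service="grpc://tf-data-service:5050",  # placeholder dispatcher URI
        # Consumers that pass the same job_name attach to one shared job
        # and split its output, instead of each reading the full dataset.
        job_name="tf_data_service_shared_job_name",
    )
)

Sharing a fixed job name (rather than the old batch-size-derived `bs_{...}` name) is what lets separately started trainers join the same service job.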
16 changes: 8 additions & 8 deletions recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
masks = []
if mask_ref is not None:
if k_in_lanes:
mask = pl.load(mask_ref, (slice(None), k_slice))
mask = mask_ref[:, k_slice]
else:
mask = pl.load(mask_ref, (k_slice, slice(None)))
mask = mask_ref[k_slice, :]

snm = jnp.where(should_not_mask, 1, 0)
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
k_sequence = k_offset + jax.lax.broadcasted_iota(
jnp.int32, (k_slice.size, bq), 0
)
q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq]
q_sequence = q_sequence_ref[:1, :] # [1, bq]
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

if q_segment_ids_ref is not None:
if k_in_lanes:
kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice]
kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice]
repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
if rem:
raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
if rem:
raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
kv_ids = pltpu.repeat(
pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
kv_segment_ids_ref[k_slice, :], repeats, axis=1
) # [k_slice, bq]
q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq]
q_ids = q_segment_ids_ref[:1, :] # [1, bq]
masks.append(q_ids == kv_ids)

if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

q = q_ref[...]
k = pl.load(k_ref, (slice_k, slice(None)))
k = k_ref[slice_k, :]
qk = jax.lax.dot_general(
q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
)
Expand Down Expand Up @@ -256,7 +256,7 @@ def body(kv_compute_index, _):
)

sv_dims = NN_DIM_NUMBERS
v = pl.load(v_ref, (slice_k, slice(None)))
v = v_ref[slice_k, :]

to_float32 = lambda x: x.astype(jnp.float32)
v = to_float32(v)
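
The changes in this file are mechanical: each `pl.load(ref, idx)` becomes the equivalent `ref[idx]` indexing sugar on the Pallas `Ref`. A toy kernel showing both forms side by side; the shapes are illustrative, and `interpret=True` lets it run without a TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # Old style: explicit pl.load with an index tuple.
  a = pl.load(x_ref, (slice(None), pl.ds(0, 4)))
  # New style: the same load written as indexing.
  b = x_ref[:, pl.ds(0, 4)]
  o_ref[...] = a + b

x = jnp.ones((8, 4), jnp.float32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 4), jnp.float32),
    interpret=True,
)(x)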
122 changes: 74 additions & 48 deletions recml/core/training/keras_trainer.py
@@ -18,6 +18,7 @@
import abc
from collections.abc import Mapping
import dataclasses
import functools
import gc
import os
import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
model: The Keras model constructed by `create_model`.
model_dir: The model directory passed to the trainer.
"""
model.save(os.path.join(model_dir, core.KERAS_MODEL_SAVEFILE))


class KerasTrainer(core.Trainer[KerasTask]):
@@ -118,6 +118,7 @@ def __init__(
max_checkpoints_to_keep: int = 5,
checkpoint_save_interval_epochs: int = 1,
rng_seed: int = core.DEFAULT_RNG_SEED,
legacy_checkpoint_format: bool = True,
):
"""Initializes the instance."""

@@ -143,60 +144,77 @@ def __init__(
self._steps_per_eval = steps_per_eval
self._continuous_eval_timeout = continuous_eval_timeout
self._steps_per_loop = steps_per_loop
self._checkpoint_manager = None
self._marker_path = os.path.join(
model_dir, core.TRAINING_COMPLETE_MARKER_FILE
)
self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR)
self._max_checkpoints_to_keep = max_checkpoints_to_keep
self._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
self._legacy_checkpoint_format = legacy_checkpoint_format

@functools.cached_property
def train_callbacks(self) -> list[keras.callbacks.Callback]:
"""Returns the training callbacks."""
if keras.backend.backend() == "jax":
self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
checkpoint_dir=self._checkpoint_dir,
max_to_keep=max_checkpoints_to_keep,
save_interval_epochs=checkpoint_save_interval_epochs,
)
self._train_callbacks = [
if self._legacy_checkpoint_format:
checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
checkpoint_dir=self._checkpoint_dir,
max_to_keep=self._max_checkpoints_to_keep,
save_interval_epochs=self._checkpoint_save_interval_epochs,
)
else:
checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2(
checkpoint_dir=self._checkpoint_dir,
max_to_keep=self._max_checkpoints_to_keep,
save_interval_epochs=self._checkpoint_save_interval_epochs,
)
return [
keras_utils.EpochSummaryCallback(
log_dir=os.path.join(model_dir, core.LOG_DIR),
steps_per_epoch=steps_per_loop,
log_dir=os.path.join(self._model_dir, core.LOG_DIR),
steps_per_epoch=self._steps_per_loop,
write_steps_per_second=True,
),
keras_utils.EpochOrbaxCheckpointAndRestoreCallback(
checkpoint_manager=self._checkpoint_manager,
checkpoint_manager=checkpoint_manager,
marker_path=self._marker_path,
),
]
self._eval_callbacks = [
return [
keras.callbacks.TensorBoard(
log_dir=os.path.join(self._model_dir, core.LOG_DIR),
write_steps_per_second=True,
),
keras.callbacks.BackupAndRestore(
backup_dir=os.path.join(self._model_dir, core.BACKUP_DIR),
),
keras.callbacks.ModelCheckpoint(
filepath=os.path.join(
self._model_dir,
core.CHECKPOINT_DIR,
"ckpt-{epoch:d}.weights.h5",
),
save_weights_only=True,
verbose=1,
),
]

@functools.cached_property
def eval_callbacks(self) -> list[keras.callbacks.Callback]:
"""Returns the evaluation callbacks."""
if keras.backend.backend() == "jax":
return [
keras_utils.EpochSummaryCallback(
log_dir=os.path.join(model_dir, core.LOG_DIR),
steps_per_epoch=steps_per_loop,
log_dir=os.path.join(self._model_dir, core.LOG_DIR),
steps_per_epoch=self._steps_per_loop,
write_steps_per_second=False,
),
]
else:
self._checkpoint_manager = None
self._train_callbacks = [
keras.callbacks.TensorBoard(
log_dir=os.path.join(model_dir, core.LOG_DIR),
write_steps_per_second=True,
),
keras.callbacks.BackupAndRestore(
backup_dir=os.path.join(model_dir, core.BACKUP_DIR),
),
keras.callbacks.ModelCheckpoint(
filepath=os.path.join(
model_dir, core.CHECKPOINT_DIR, "ckpt-{epoch:d}.weights.h5"
),
save_weights_only=True,
verbose=1,
),
]
self._eval_callbacks = [
keras.callbacks.TensorBoard(
log_dir=os.path.join(model_dir, core.LOG_DIR),
write_steps_per_second=True,
),
]
return [
keras.callbacks.TensorBoard(
log_dir=os.path.join(self._model_dir, core.LOG_DIR),
write_steps_per_second=True,
),
]

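Moving the callback lists out of `__init__` and into `functools.cached_property` builds each list lazily on first access and reuses it afterwards, so `train`, `evaluate`, and `train_and_evaluate` all see the same callback instances. A standalone illustration of the caching behavior:

import functools

class Example:

  @functools.cached_property
  def callbacks(self) -> list[int]:
    print("built once")
    return [1, 2, 3]

e = Example()
assert e.callbacks is e.callbacks  # built on first access, cached after
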
def _maybe_get_model_kws(
self, task: KerasTask, dataset: tf.data.Dataset
@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
dataset,
epochs=self._train_epochs,
steps_per_epoch=self._steps_per_loop,
callbacks=self._train_callbacks,
callbacks=self.train_callbacks,
)
model.summary(print_fn=logging.info)

@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
if keras.backend.backend() == "jax":
[tb_cbk] = [
cbk
for cbk in self._eval_callbacks
for cbk in self.eval_callbacks
if isinstance(cbk, keras_utils.EpochSummaryCallback)
]
epoch_start_time = time.time()
history = model.evaluate(
dataset,
steps=self._steps_per_eval,
callbacks=self._eval_callbacks,
callbacks=self.eval_callbacks,
return_dict=True,
)
epoch_dt = time.time() - epoch_start_time
@@ -257,7 +275,7 @@ return model.evaluate(
return model.evaluate(
dataset,
steps=self._steps_per_eval,
callbacks=self._eval_callbacks,
callbacks=self.eval_callbacks,
)

def train_and_evaluate(self, task: KerasTask) -> core.Logs:
@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
steps_per_epoch=self._steps_per_loop,
# Explicitly set to None for deterministic evaluation.
validation_steps=None,
callbacks=self._train_callbacks,
callbacks=self.train_callbacks,
)
model.summary(print_fn=logging.info)

@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
else:
steps_msg = "running complete evaluation..."

use_legacy_checkpoint_format = self._legacy_checkpoint_format

class _RestoreCallback(keras.callbacks.Callback):
"""Callback for restoring the model from the latest checkpoint."""

def __init__(
self,
@@ -319,9 +340,14 @@ def __init__(
self._epoch = epoch

def on_test_begin(self, logs: Mapping[str, Any] | None = None):
keras_utils.restore_keras_model(
model, self._checkpoint_dir, step=self._epoch
)
if use_legacy_checkpoint_format:
keras_utils.restore_keras_model(
model, self._checkpoint_dir, step=self._epoch
)
else:
keras_utils.restore_keras_checkpoint(
self._checkpoint_dir, model=model, epoch=self._epoch
)

history = None
for epoch in ocp.checkpoint_utils.checkpoints_iterator(
@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
restore_callback = _RestoreCallback(self._checkpoint_dir, epoch)
[tb_cbk] = [
cbk
for cbk in self._eval_callbacks
for cbk in self.eval_callbacks
if isinstance(cbk, keras_utils.EpochSummaryCallback)
]
try:
@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
history = model.evaluate(
eval_dataset,
steps=self._steps_per_eval,
callbacks=[restore_callback] + self._eval_callbacks,
callbacks=[restore_callback] + self.eval_callbacks,
return_dict=True,
)

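A sketch of opting into the new checkpoint layout through the added constructor argument; every argument other than `legacy_checkpoint_format` is illustrative, and the full signature should be consulted:

trainer = KerasTrainer(
    model_dir="/tmp/my_model",  # illustrative path
    legacy_checkpoint_format=False,  # use KerasOrbaxCheckpointManagerV2
)

With `legacy_checkpoint_format=False`, continuous evaluation restores weights via `keras_utils.restore_keras_checkpoint` instead of the legacy `keras_utils.restore_keras_model`.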