16 changes: 8 additions & 8 deletions recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]
 
     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
 
       assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(
 
   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)
 
   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
 
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )
 
     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]
 
     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
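The hstu_ops.py edits are a mechanical migration inside the Pallas kernel: `pl.load(ref, indices)` becomes NumPy-style indexing on the ref, with `pl.ds(...)` still usable inside the subscript for dynamic slices. A minimal, self-contained sketch of the idiom, assuming a JAX version with Pallas support; the kernel, names, and shapes here are illustrative, not from this PR:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _row_block_sum_kernel(x_ref, o_ref):
  # Indexing spelling of pl.load(x_ref, (pl.ds(0, 8), slice(None))):
  # a dynamic slice of 8 rows, all columns.
  block = x_ref[pl.ds(0, 8), :]
  # Writing through a sliced ref likewise replaces an explicit pl.store.
  o_ref[:1, :] = jnp.sum(block, axis=0, keepdims=True)


@jax.jit
def row_block_sum(x: jax.Array) -> jax.Array:
  return pl.pallas_call(
      _row_block_sum_kernel,
      out_shape=jax.ShapeDtypeStruct((1, x.shape[1]), x.dtype),
  )(x)

# Usage: row_block_sum(jnp.ones((8, 128))) -> a (1, 128) array of 8s.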
122 changes: 74 additions & 48 deletions recml/core/training/keras_trainer.py
@@ -18,6 +18,7 @@
 import abc
 from collections.abc import Mapping
 import dataclasses
+import functools
 import gc
 import os
 import time
@@ -96,7 +97,6 @@ def export_model(self, model: keras.Model, model_dir: str):
       model: The Keras model constructed by `create_model`.
       model_dir: The model directory passed to the trainer.
     """
-    model.save(os.path.join(model_dir, core.KERAS_MODEL_SAVEFILE))
 
 
 class KerasTrainer(core.Trainer[KerasTask]):
@@ -118,6 +118,7 @@ def __init__(
       max_checkpoints_to_keep: int = 5,
       checkpoint_save_interval_epochs: int = 1,
       rng_seed: int = core.DEFAULT_RNG_SEED,
+      legacy_checkpoint_format: bool = True,
   ):
     """Initializes the instance."""
 
@@ -143,60 +144,77 @@ def __init__(
     self._steps_per_eval = steps_per_eval
     self._continuous_eval_timeout = continuous_eval_timeout
     self._steps_per_loop = steps_per_loop
+    self._checkpoint_manager = None
     self._marker_path = os.path.join(
         model_dir, core.TRAINING_COMPLETE_MARKER_FILE
     )
     self._checkpoint_dir = os.path.join(model_dir, core.CHECKPOINT_DIR)
+    self._max_checkpoints_to_keep = max_checkpoints_to_keep
+    self._checkpoint_save_interval_epochs = checkpoint_save_interval_epochs
+    self._legacy_checkpoint_format = legacy_checkpoint_format
 
+  @functools.cached_property
+  def train_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the training callbacks."""
     if keras.backend.backend() == "jax":
-      self._checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
-          checkpoint_dir=self._checkpoint_dir,
-          max_to_keep=max_checkpoints_to_keep,
-          save_interval_epochs=checkpoint_save_interval_epochs,
-      )
-      self._train_callbacks = [
+      if self._legacy_checkpoint_format:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManager(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      else:
+        checkpoint_manager = keras_utils.KerasOrbaxCheckpointManagerV2(
+            checkpoint_dir=self._checkpoint_dir,
+            max_to_keep=self._max_checkpoints_to_keep,
+            save_interval_epochs=self._checkpoint_save_interval_epochs,
+        )
+      return [
           keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
               write_steps_per_second=True,
           ),
           keras_utils.EpochOrbaxCheckpointAndRestoreCallback(
-              checkpoint_manager=self._checkpoint_manager,
+              checkpoint_manager=checkpoint_manager,
              marker_path=self._marker_path,
           ),
       ]
-      self._eval_callbacks = [
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+        keras.callbacks.BackupAndRestore(
+            backup_dir=os.path.join(self._model_dir, core.BACKUP_DIR),
+        ),
+        keras.callbacks.ModelCheckpoint(
+            filepath=os.path.join(
+                self._model_dir,
+                core.CHECKPOINT_DIR,
+                "ckpt-{epoch:d}.weights.h5",
+            ),
+            save_weights_only=True,
+            verbose=1,
+        ),
+    ]
+
+  @functools.cached_property
+  def eval_callbacks(self) -> list[keras.callbacks.Callback]:
+    """Returns the evaluation callbacks."""
+    if keras.backend.backend() == "jax":
+      return [
           keras_utils.EpochSummaryCallback(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              steps_per_epoch=steps_per_loop,
+              log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+              steps_per_epoch=self._steps_per_loop,
               write_steps_per_second=False,
           ),
       ]
-    else:
-      self._checkpoint_manager = None
-      self._train_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-          keras.callbacks.BackupAndRestore(
-              backup_dir=os.path.join(model_dir, core.BACKUP_DIR),
-          ),
-          keras.callbacks.ModelCheckpoint(
-              filepath=os.path.join(
-                  model_dir, core.CHECKPOINT_DIR, "ckpt-{epoch:d}.weights.h5"
-              ),
-              save_weights_only=True,
-              verbose=1,
-          ),
-      ]
-      self._eval_callbacks = [
-          keras.callbacks.TensorBoard(
-              log_dir=os.path.join(model_dir, core.LOG_DIR),
-              write_steps_per_second=True,
-          ),
-      ]
+    return [
+        keras.callbacks.TensorBoard(
+            log_dir=os.path.join(self._model_dir, core.LOG_DIR),
+            write_steps_per_second=True,
+        ),
+    ]
 
   def _maybe_get_model_kws(
       self, task: KerasTask, dataset: tf.data.Dataset
@@ -218,7 +236,7 @@ def train(self, task: KerasTask) -> core.Logs:
         dataset,
         epochs=self._train_epochs,
         steps_per_epoch=self._steps_per_loop,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)
 
@@ -237,14 +255,14 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     if keras.backend.backend() == "jax":
       [tb_cbk] = [
           cbk
-          for cbk in self._eval_callbacks
+          for cbk in self.eval_callbacks
           if isinstance(cbk, keras_utils.EpochSummaryCallback)
       ]
       epoch_start_time = time.time()
       history = model.evaluate(
           dataset,
           steps=self._steps_per_eval,
-          callbacks=self._eval_callbacks,
+          callbacks=self.eval_callbacks,
           return_dict=True,
       )
       epoch_dt = time.time() - epoch_start_time
@@ -257,7 +275,7 @@ def evaluate(self, task: KerasTask) -> core.Logs:
     return model.evaluate(
         dataset,
         steps=self._steps_per_eval,
-        callbacks=self._eval_callbacks,
+        callbacks=self.eval_callbacks,
     )
 
   def train_and_evaluate(self, task: KerasTask) -> core.Logs:
@@ -277,7 +295,7 @@ def train_and_evaluate(self, task: KerasTask) -> core.Logs:
         steps_per_epoch=self._steps_per_loop,
         # Explicitly set to None for deterministic evaluation.
         validation_steps=None,
-        callbacks=self._train_callbacks,
+        callbacks=self.train_callbacks,
     )
     model.summary(print_fn=logging.info)
 
@@ -308,7 +326,10 @@ def timeout_fn() -> bool:
       else:
         steps_msg = "running complete evaluation..."
 
+      use_legacy_checkpoint_format = self._legacy_checkpoint_format
+
       class _RestoreCallback(keras.callbacks.Callback):
+        """Callback for restoring the model from the latest checkpoint."""
 
         def __init__(
             self,
@@ -319,9 +340,14 @@ def __init__(
           self._epoch = epoch
 
         def on_test_begin(self, logs: Mapping[str, Any] | None = None):
-          keras_utils.restore_keras_model(
-              model, self._checkpoint_dir, step=self._epoch
-          )
+          if use_legacy_checkpoint_format:
+            keras_utils.restore_keras_model(
+                model, self._checkpoint_dir, step=self._epoch
+            )
+          else:
+            keras_utils.restore_keras_checkpoint(
+                self._checkpoint_dir, model=model, epoch=self._epoch
+            )
 
       history = None
       for epoch in ocp.checkpoint_utils.checkpoints_iterator(
@@ -332,7 +358,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
         restore_callback = _RestoreCallback(self._checkpoint_dir, epoch)
         [tb_cbk] = [
             cbk
-            for cbk in self._eval_callbacks
+            for cbk in self.eval_callbacks
             if isinstance(cbk, keras_utils.EpochSummaryCallback)
         ]
         try:
@@ -346,7 +372,7 @@ def on_test_begin(self, logs: Mapping[str, Any] | None = None):
           history = model.evaluate(
               eval_dataset,
               steps=self._steps_per_eval,
-              callbacks=[restore_callback] + self._eval_callbacks,
+              callbacks=[restore_callback] + self.eval_callbacks,
               return_dict=True,
           )
 
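Two things change in keras_trainer.py: the callback lists move out of `__init__` into `train_callbacks`/`eval_callbacks` properties memoized with `functools.cached_property`, and a new `legacy_checkpoint_format` flag (default `True`) selects between `KerasOrbaxCheckpointManager` and `KerasOrbaxCheckpointManagerV2`. A reduced sketch of the memoization pattern only, using hypothetical stand-ins rather than the real recml API:

import functools


class TrainerSketch:
  """Illustrates lazy, cached callback construction."""

  def __init__(self, log_dir: str, legacy_checkpoint_format: bool = True):
    self._log_dir = log_dir
    self._legacy_checkpoint_format = legacy_checkpoint_format

  @functools.cached_property
  def train_callbacks(self) -> list[str]:
    # Built on first access, after __init__ has finished; the result is
    # cached on the instance, so train() and train_and_evaluate() share
    # one callback list.
    fmt = "legacy" if self._legacy_checkpoint_format else "v2"
    return [f"checkpoint-manager:{fmt}", f"summary:{self._log_dir}"]


trainer = TrainerSketch("/tmp/logs", legacy_checkpoint_format=False)
assert trainer.train_callbacks is trainer.train_callbacks  # same cached list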
17 changes: 15 additions & 2 deletions recml/core/training/keras_trainer_test.py
@@ -59,11 +59,23 @@ def setUp(self):
           "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
       },
       {
-          "testcase_name": "continuous_eval",
+          "testcase_name": "continuous_eval_",
           "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
       },
+      {
+          "testcase_name": "train_and_eval_legacy_checkpoint_format",
+          "mode": core.Experiment.Mode.TRAIN_AND_EVAL,
+          "legacy_checkpoint_format": True,
+      },
+      {
+          "testcase_name": "continuous_eval_legacy_checkpoint_format",
+          "mode": core.Experiment.Mode.CONTINUOUS_EVAL,
+          "legacy_checkpoint_format": True,
+      },
   )
-  def test_keras_task_and_trainer(self, mode: str):
+  def test_keras_task_and_trainer(
+      self, mode: str, legacy_checkpoint_format: bool = False
+  ):
     if keras.backend.backend() == "jax":
       distribution = keras.distribution.DataParallel()
     else:
@@ -78,6 +90,7 @@ def test_keras_task_and_trainer(self, mode: str):
         steps_per_loop=2,
         model_dir=self.create_tempdir().full_path,
         continuous_eval_timeout=5,
+        legacy_checkpoint_format=legacy_checkpoint_format,
     )
     experiment = core.Experiment(_KerasTask(), trainer)
 
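These cases work because `parameterized.named_parameters` passes each dict's keys (minus `testcase_name`) as keyword arguments, and the test method supplies a default for the new parameter, so the pre-existing cases do not have to name it. A stripped-down, hypothetical example of the same pattern:

from absl.testing import absltest
from absl.testing import parameterized


class FlagDefaultTest(parameterized.TestCase):

  @parameterized.named_parameters(
      {"testcase_name": "default_format"},
      {"testcase_name": "legacy_format", "legacy_checkpoint_format": True},
  )
  def test_flag(self, legacy_checkpoint_format: bool = False):
    # Older cases omit the kwarg and fall back to the default False.
    self.assertIsInstance(legacy_checkpoint_format, bool)


if __name__ == "__main__":
  absltest.main()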
10 changes: 5 additions & 5 deletions recml/core/training/partitioning.py
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
 
     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -130,15 +130,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)
 
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )
 
     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)
 
     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh
 
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
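The partitioning.py changes are a one-for-one API migration: every `jax.sharding.use_mesh(...)` context becomes `jax.set_mesh(...)`, the newer JAX entry point, which the code continues to use as a context manager. A minimal sketch, assuming a recent JAX release that provides `jax.set_mesh` and `jax.make_mesh`; axis names and the computation are illustrative:

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# One-dimensional mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

with jax.set_mesh(mesh):  # replaces: with jax.sharding.use_mesh(mesh):
  # Inside the context the mesh is current, mirroring how partition_init
  # and partition_step wrap their jit calls above.
  x = jnp.arange(2.0 * jax.device_count())
  sharded = jax.device_put(x, NamedSharding(mesh, P("data")))
  print(jax.jit(lambda a: a * 2)(sharded).sharding)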