Draft
Changes from 84 commits
Commits (88 total)
1d81455
Merge pull request #847 from mlcommons/dev
priyakasimbeg Feb 27, 2025
da5f85a
first LM commit
Niccolo-Ajroldi Mar 11, 2025
a12a364
lm data pipeline
Niccolo-Ajroldi Mar 12, 2025
ca83ab8
testing
Niccolo-Ajroldi Mar 14, 2025
e3e78dc
LM workload tested torch pipeline
Niccolo-Ajroldi Mar 17, 2025
e619495
LM workload - fix torch tests
Niccolo-Ajroldi Mar 17, 2025
d8e9c56
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
6b4ff12
add LM tests, remove dev files
Niccolo-Ajroldi Mar 18, 2025
3c5c847
Stop tracking .gitignore
Niccolo-Ajroldi Mar 18, 2025
20d841b
Remove dev/ from repo, keep locally
Niccolo-Ajroldi Mar 18, 2025
f3ba059
fix comments
Niccolo-Ajroldi Mar 18, 2025
381451f
add class specifications
Niccolo-Ajroldi Mar 18, 2025
f111d2e
add workload LM info
Niccolo-Ajroldi Mar 18, 2025
808d398
restore data_utils.py tree map
Niccolo-Ajroldi Mar 18, 2025
35f8f89
fixed NFS bug
Niccolo-Ajroldi Mar 18, 2025
cbb6ee6
train/val split before concat
Niccolo-Ajroldi Mar 18, 2025
868987c
renamed datasets to avoid conflict with HF
Niccolo-Ajroldi Mar 19, 2025
8191f6d
Merge remote-tracking branch 'upstream/lm_workload' into lm_workload
Niccolo-Ajroldi Mar 19, 2025
dd59ded
renamed datasets to dataset
Niccolo-Ajroldi Mar 19, 2025
496b9c3
fix style
Niccolo-Ajroldi Mar 20, 2025
50989eb
fix formatting
Niccolo-Ajroldi Mar 20, 2025
5af0fdc
fix style
Niccolo-Ajroldi Mar 20, 2025
2683099
fix style
Niccolo-Ajroldi Mar 20, 2025
6b7ee29
fix yapf
Niccolo-Ajroldi Mar 20, 2025
46b645b
fix style
Niccolo-Ajroldi Mar 20, 2025
b3ae647
HF datasets pipeline
rka97 Mar 27, 2025
f095d4b
Testing with linear model
rka97 Mar 27, 2025
4189ae0
Merge branch 'jit_switch' into lm_workload
rka97 Mar 27, 2025
0c22f3d
lm workload with linear model
rka97 Apr 3, 2025
99c7b9b
add nanodo model
rka97 Apr 3, 2025
706d9f7
torch model
rka97 Apr 3, 2025
c335e34
lm workload dataset integration in jax
rka97 May 29, 2025
2d54365
lm workload dataset integration in jax
rka97 May 29, 2025
af8cce4
set package versions for transformers and datasets
priyakasimbeg Jun 5, 2025
d68c54e
use train_test_split method to shuffle and split fineweb-edu dataset
priyakasimbeg Jun 5, 2025
9737367
modifications to fwedu datasetup
priyakasimbeg Jun 9, 2025
1bf0750
rename fwedu data dir
priyakasimbeg Jun 9, 2025
a333391
fix
priyakasimbeg Jun 9, 2025
05dc4dd
add back batch mapping in tokenization for fwedu
priyakasimbeg Jun 9, 2025
b374cf8
debugging
priyakasimbeg Jun 10, 2025
c0c1e3c
debugging
priyakasimbeg Jun 10, 2025
f76dc39
debugging
priyakasimbeg Jun 10, 2025
e805fa7
use tfds to shuffle and split dataset
priyakasimbeg Jun 10, 2025
362cbda
Merge remote-tracking branch 'origin/dev' into lm_workload
rka97 Sep 11, 2025
c9e9abc
add command for fineweb-edu
priyakasimbeg Oct 2, 2025
e4323de
fix
priyakasimbeg Oct 2, 2025
f0c6e75
update calls to sharing utils
priyakasimbeg Oct 3, 2025
f4ffbe7
Fix torch sharding issue, update input pipeline and workload classes …
rka97 Oct 6, 2025
5c85c7e
test working, lm workload training not working (debugging)
rka97 Oct 6, 2025
a59dfda
updates to input_pipeline and model spec
priyakasimbeg Oct 6, 2025
1c3cb66
add defaults for lm workload
priyakasimbeg Oct 6, 2025
af91b12
refactor eval pipeline and loss fn for lm
priyakasimbeg Oct 7, 2025
6b55adf
refactor evaluation pipeline for lm
priyakasimbeg Oct 7, 2025
210d671
remove temporary flag for hlo dumps
priyakasimbeg Oct 7, 2025
0ad7788
fix in workload target condition check
priyakasimbeg Oct 7, 2025
01921d5
fix in mlp for glu
priyakasimbeg Oct 8, 2025
e420450
Fix OOM error in weighted cross entropy calculation
rka97 Oct 10, 2025
3b31ad5
fix issue with checkpointing bool
rka97 Oct 10, 2025
bbc114f
increase buffer size
priyakasimbeg Oct 10, 2025
f531b35
Merge branch 'lm_workload_priya' of github.com:mlcommons/algorithmic-…
priyakasimbeg Oct 10, 2025
2b162e8
remove _eval_batch from jax workload
priyakasimbeg Oct 10, 2025
617e1a3
add todo for pytorch _eval_batch cleanup
priyakasimbeg Oct 10, 2025
bebc80a
Merge pull request #891 from mlcommons/lm_workload_priya
rka97 Oct 15, 2025
64ea658
add target setting algorithm for fineweb edu lm workload
priyakasimbeg Oct 16, 2025
b38ade0
update step hint for lm workload
priyakasimbeg Oct 16, 2025
65369f2
update target
priyakasimbeg Oct 16, 2025
6171b2d
update eval split sizes for lm workload and target setting point
priyakasimbeg Oct 16, 2025
d7a885c
Porting workload input pipeline to torch
rka97 Oct 17, 2025
f111aea
Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici…
rka97 Oct 17, 2025
1f0439a
Fix OOM bug in lm eval
rka97 Oct 18, 2025
b11c193
repeat dataset
rka97 Oct 18, 2025
42d1d1a
label smoothing default fix
priyakasimbeg Oct 20, 2025
c334c97
finish merge
priyakasimbeg Oct 20, 2025
d95f2bf
Make sure to take the correct number of batches in lm
rka97 Oct 21, 2025
7deb070
Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici…
rka97 Oct 21, 2025
0dc16db
Properly handle repetition in LM training and evaluation splits
rka97 Oct 21, 2025
7edb702
move eval_batch from shared class to framework specific classes since…
priyakasimbeg Oct 21, 2025
0879e68
finish merge
priyakasimbeg Oct 21, 2025
73e3ea6
Refactor imports and clean up unused code in LM workload and related …
rka97 Oct 21, 2025
91988af
pass linter checks
rka97 Oct 21, 2025
bb4a380
Refactor loss function in LM workloads to unify label handling and im…
rka97 Oct 21, 2025
a58fbd5
Fix init in both models to be the same, add lm model diff test
rka97 Oct 21, 2025
b59afa0
Refactor model configuration classes to make them consistent between …
rka97 Oct 21, 2025
d35cdde
Add query-key normalization to CausalAttn and Attention classes, incl…
rka97 Oct 23, 2025
ffb8163
update target
priyakasimbeg Oct 24, 2025
2cc9dff
Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici…
priyakasimbeg Oct 24, 2025
202e5cb
add pytorch nadamw_target_setting
priyakasimbeg Oct 26, 2025
98e491a
docker updates for a100
priyakasimbeg Oct 27, 2025
2 changes: 1 addition & 1 deletion .gitignore
@@ -25,4 +25,4 @@ scoring/plots/
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_0/eval_measurements.csv
!scoring/test_data/experiment_dir/study_0/mnist_jax/trial_1/eval_measurements.csv

algoperf/_version.py
algoperf/_version.py
49 changes: 48 additions & 1 deletion algoperf/checkpoint_utils.py
@@ -5,14 +5,16 @@
"""

import os
from typing import Sequence, Tuple
from typing import Optional, Sequence, Tuple

import numpy as np
import orbax.checkpoint as ocp
import torch
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
from orbax.checkpoint.type_handlers import NumpyHandler
from tensorflow.io import gfile # pytype: disable=import-error

from algoperf import spec
@@ -30,6 +32,51 @@
]


class BoolHandler(NumpyHandler):
"""
An implementation of TypeHandler for np.bool_ that inherits from NumpyHandler.
It works by treating the scalar as a 0-dimensional array.
"""

def typestr(self) -> str:
"""Unique string identifier for this handler."""
return 'np.bool_'

async def serialize(
self,
values: Sequence[np.bool_],
infos: Sequence,
args: Optional[Sequence[ocp.SaveArgs]] = None,
):
"""
Serializes a sequence of np.bool_ scalars by first converting them
to 0-dim numpy arrays and then calling the parent NumpyHandler.
"""
# Convert each scalar np.bool_ to a 0-dimensional np.ndarray
array_values = [np.asarray(v, dtype=np.bool_) for v in values]
# Use the parent class's robust serialization logic
return await super().serialize(array_values, infos, args)

async def deserialize(
self,
infos: Sequence,
args: Optional[Sequence[ocp.RestoreArgs]] = None,
) -> Sequence[np.bool_]:
"""
Deserializes into a sequence of np.bool_ scalars by calling the
parent handler and then converting the resulting 0-dim arrays.
"""
# Parent deserialize will return a sequence of 0-dimensional np.ndarray
results = await super().deserialize(infos, args)

# Convert each 0-d array back to an np.bool_ scalar using .item()
scalar_results = [np.bool_(r.item()) for r in results]
return scalar_results


ocp.type_handlers.register_type_handler(np.bool_, BoolHandler(), override=True)


def maybe_restore_checkpoint(
framework: str,
optimizer_state: spec.OptimizerState,
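A minimal sketch of the new handler in action, assuming a PyTreeCheckpointer and a hypothetical train state with a scalar np.bool_ leaf (only the BoolHandler and its registration come from this diff):

import numpy as np
import orbax.checkpoint as ocp

# Hypothetical state: a scalar np.bool_ leaf such as a "target reached" flag.
state = {'step': np.int64(1000), 'target_reached': np.bool_(False)}

checkpointer = ocp.PyTreeCheckpointer()
# With BoolHandler registered for np.bool_, the scalar is stored as a
# 0-dimensional array instead of tripping over an unsupported leaf type.
checkpointer.save('/tmp/example_ckpt', state)

# deserialize() converts the stored 0-dim array back into an np.bool_ scalar.
restored = checkpointer.restore('/tmp/example_ckpt')
print(type(restored['target_reached']))  # expected: <class 'numpy.bool_'>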
2 changes: 2 additions & 0 deletions algoperf/param_utils.py
@@ -44,6 +44,8 @@ def pytorch_param_types(
param_types[name] = spec.ParameterType.ATTENTION_BIAS
elif 'in_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'qkv' in name:
param_types[name] = spec.ParameterType.ATTENTION_QKV
elif 'kv_proj' in name:
param_types[name] = spec.ParameterType.ATTENTION_KV
elif 'k_proj' in name or 'key' in name:
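For illustration, the kind of parameter name the new branch is meant to catch (the name below is hypothetical, not taken from the workload):

# A fused attention projection, e.g. from the new LM model, typically shows up
# as a single parameter whose name contains 'qkv'.
name = 'blocks.0.attn.qkv.weight'  # hypothetical parameter name
# It does not match the earlier 'in_proj' branch, but the added `elif 'qkv' in
# name:` branch classifies it as spec.ParameterType.ATTENTION_QKV rather than
# letting it fall through to the k/q/v-specific branches below.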
6 changes: 4 additions & 2 deletions algoperf/pytorch_utils.py
@@ -27,7 +27,9 @@ def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
return use_pytorch_ddp, rank, device, n_gpus


def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:
def pytorch_init(
use_pytorch_ddp: bool, rank: int, profiler: Profiler, limit_tf_threads=True
) -> None:
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# Only use CPU for Jax to avoid memory issues.
@@ -39,7 +41,7 @@ def pytorch_init(use_pytorch_ddp: bool, rank: int, profiler: Profiler) -> None:

if use_pytorch_ddp:
# Avoid tf input pipeline creating too many threads.
if rank != 0:
if rank != 0 and limit_tf_threads:
tf.config.threading.set_intra_op_parallelism_threads(1)
tf.config.threading.set_inter_op_parallelism_threads(1)

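A short usage sketch of the new flag; the caller and the profiler placeholder are assumptions, only the signature above comes from the diff:

from algoperf import pytorch_utils

use_pytorch_ddp, rank, device, n_gpus = pytorch_utils.pytorch_setup()
profiler = ...  # an algoperf Profiler (or pass-through) instance

# For a workload whose input pipeline is not built on tf.data (for example one
# backed by Hugging Face datasets), the per-rank TF thread limiting is not
# needed, so the caller can opt out of it.
pytorch_utils.pytorch_init(
    use_pytorch_ddp, rank, profiler, limit_tf_threads=False
)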
4 changes: 2 additions & 2 deletions algoperf/random_utils.py
@@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:

def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
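For context, a standalone NumPy sketch of why unsigned seeds matter here (plain NumPy, not this module's helpers): np.random.RandomState only accepts seeds in [0, 2**32 - 1], so signed 32-bit seeds have to be reinterpreted as unsigned before they are reused.

import numpy as np

def signed_to_unsigned(seed: int) -> int:
  # Reinterpret a signed 32-bit seed as unsigned, e.g. -1 -> 2**32 - 1,
  # so it is a valid RandomState seed.
  return seed % 2**32

rng = np.random.RandomState(seed=signed_to_unsigned(-1))
child_seed = rng.randint(0, 2**32, dtype=np.uint32)  # unsigned 32-bit child seed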
Empty file.
153 changes: 153 additions & 0 deletions algoperf/workloads/lm/input_pipeline.py
@@ -0,0 +1,153 @@
"""Input pipeline for a LM dataset."""

import functools
import os
from typing import Optional

import jax
import tensorflow as tf

from algoperf import data_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE
PAD_ID = tf.constant(-1, dtype=tf.int64)

TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'}

SEQUENCE_LENGTH = 1024
MAX_CORPUS_CHARS = 1_000_000_000
SHUFFLE_BUFFER_SIZE = 1000
VOCAB_SIZE = 50_257


def batch_with_padding(
dataset: tf.data.Dataset,
batch_size,
padded_shapes=None,
padding_id=PAD_ID,
):
"""Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size.

Args:
dataset: tf.data.Dataset
batch_size: batch size of resulting batched dataset
padded_shapes: shapes of the padded batches
padding_id: value used to pad elements of the final, smaller batch

Returns:
tf.data.Dataset in which the final, smaller batch is padded with padding_id
so that every batch has exactly batch_size elements.
"""
batched_dataset = dataset.batch(batch_size, drop_remainder=False)

# tf.data.Dataset.padded_batch pads the elements within each batch, so we call
# it again with batch_size=1 to pad each batch of the original batched dataset.
padded_batched_dataset = batched_dataset.padded_batch(
1, padded_shapes=padded_shapes, padding_values=padding_id
)

# Remove extra dimension resulting from the batch_size=1.
padded_batched_dataset = padded_batched_dataset.unbatch()

return padded_batched_dataset


def get_data_iter(
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
batch_size: int,
num_batches: Optional[int] = None,
):
ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches)

it = map(
functools.partial(
data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size
),
ds,
)

return iter(it)


def get_lm_dataset(
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
batch_size: int,
num_batches: Optional[int] = None,
):
"""Load preprocessed TF dataset."""
if split not in TFDS_SPLIT_NAME:
raise NotImplementedError

shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1)

data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split])
tokens_ds = tf.data.Dataset.load(data_dir)

# Flatten the stored examples into a stream of individual tokens.
tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices)

# Pack tokens into sequences of SEQUENCE_LENGTH + 1.
sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)

# Split each sequence into inputs and shifted next-token targets.
sequences_ds = sequences_ds.map(
lambda x: {
'inputs': x['input_ids'][:SEQUENCE_LENGTH],
'targets': x['input_ids'][1:],
},
num_parallel_calls=AUTOTUNE,
)
if split == 'train':
ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)
ds = ds.batch(batch_size, drop_remainder=False)
ds = ds.take(num_batches) if num_batches is not None else ds
ds = ds.repeat()
ds = ds.map(
lambda x: {
'inputs': x['inputs'],
'targets': x['targets'],
'weights': None,
}
)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
elif split == 'eval_train':
ds = batch_with_padding(
sequences_ds,
batch_size,
padded_shapes={
'inputs': (batch_size, None),
'targets': (batch_size, None),
},
)
ds = ds.take(num_batches) if num_batches is not None else ds
ds = ds.repeat()
ds = ds.map(
lambda x: {
'inputs': x['inputs'],
'targets': x['targets'],
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
}
)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
elif split == 'validation':
ds = batch_with_padding(
sequences_ds,
batch_size,
padded_shapes={
'inputs': (batch_size, None),
'targets': (batch_size, None),
},
)
ds = ds.take(num_batches) if num_batches is not None else ds
ds = ds.repeat()
ds = ds.map(
lambda x: {
'inputs': x['inputs'],
'targets': x['targets'],
'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
}
)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
return ds
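A hedged sketch of how a workload might drive this pipeline; the data directory and batch size are placeholders, and the shapes in the comments are expectations rather than guarantees:

import jax

from algoperf.workloads.lm import input_pipeline

data_rng = jax.random.PRNGKey(0)

# `data_dir` is assumed to hold pre-tokenized splits written with
# tf.data.Dataset.save under `train/` and `val/` subdirectories, matching
# TFDS_SPLIT_NAME above.
train_iter = input_pipeline.get_data_iter(
    data_rng,
    split='train',
    data_dir='/data/fineweb_edu',  # hypothetical path
    batch_size=8,
)

batch = next(train_iter)
# Each batch is a dict with 'inputs' and 'targets' (plus 'weights' for the
# padded eval splits), sharded across local devices by
# data_utils.shard_and_maybe_pad_np.
print(batch['inputs'].shape)  # e.g. (num_devices, 8 // num_devices, 1024)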
Empty file.