# [WIP] LM Workload #860
**Status:** Draft. rka97 wants to merge 88 commits into `dev` from `lm_workload`.
## Commits (88 total; changes shown from 84)
- `1d81455` Merge pull request #847 from mlcommons/dev (priyakasimbeg)
- `da5f85a` first LM commit (Niccolo-Ajroldi)
- `a12a364` lm data pipeline (Niccolo-Ajroldi)
- `ca83ab8` testing (Niccolo-Ajroldi)
- `e3e78dc` LM workload tested torch pipeline (Niccolo-Ajroldi)
- `e619495` LM workload - fix torch tests (Niccolo-Ajroldi)
- `d8e9c56` add LM tests, remove dev files (Niccolo-Ajroldi)
- `6b4ff12` add LM tests, remove dev files (Niccolo-Ajroldi)
- `3c5c847` Stop tracking .gitignore (Niccolo-Ajroldi)
- `20d841b` Remove dev/ from repo, keep locally (Niccolo-Ajroldi)
- `f3ba059` fix comments (Niccolo-Ajroldi)
- `381451f` add class specifications (Niccolo-Ajroldi)
- `f111d2e` add workload LM info (Niccolo-Ajroldi)
- `808d398` restore data_utils.py tree map (Niccolo-Ajroldi)
- `35f8f89` fixed NFS bug (Niccolo-Ajroldi)
- `cbb6ee6` train/val split before concat (Niccolo-Ajroldi)
- `868987c` renamed datasets to avoid conflict with HF (Niccolo-Ajroldi)
- `8191f6d` Merge remote-tracking branch 'upstream/lm_workload' into lm_workload (Niccolo-Ajroldi)
- `dd59ded` renamed datasets to dataset (Niccolo-Ajroldi)
- `496b9c3` fix style (Niccolo-Ajroldi)
- `50989eb` fix formatting (Niccolo-Ajroldi)
- `5af0fdc` fix style (Niccolo-Ajroldi)
- `2683099` fix style (Niccolo-Ajroldi)
- `6b7ee29` fix yapf (Niccolo-Ajroldi)
- `46b645b` fix style (Niccolo-Ajroldi)
- `b3ae647` HF datasets pipeline (rka97)
- `f095d4b` Testing with linear model (rka97)
- `4189ae0` Merge branch 'jit_switch' into lm_workload (rka97)
- `0c22f3d` lm workload with linear model (rka97)
- `99c7b9b` add nanodo model (rka97)
- `706d9f7` torch model (rka97)
- `c335e34` lm workload dataset integration in jax (rka97)
- `2d54365` lm workload dataset integration in jax (rka97)
- `af8cce4` set package versions for transformers and datasets (priyakasimbeg)
- `d68c54e` use train_test_split method to shuffle and split fineweb-edu dataset (priyakasimbeg)
- `9737367` modifications to fwedu datasetup (priyakasimbeg)
- `1bf0750` rename fwedu data dir (priyakasimbeg)
- `a333391` fix (priyakasimbeg)
- `05dc4dd` add back batch mapping in tokenization for fwedu (priyakasimbeg)
- `b374cf8` debugging (priyakasimbeg)
- `c0c1e3c` debugging (priyakasimbeg)
- `f76dc39` debugging (priyakasimbeg)
- `e805fa7` use tfds to shuffle and split dataset (priyakasimbeg)
- `362cbda` Merge remote-tracking branch 'origin/dev' into lm_workload (rka97)
- `c9e9abc` add command for fineweb-edu (priyakasimbeg)
- `e4323de` fix (priyakasimbeg)
- `f0c6e75` update calls to sharing utils (priyakasimbeg)
- `f4ffbe7` Fix torch sharding issue, update input pipeline and workload classes … (rka97)
- `5c85c7e` test working, lm workload training not working (debugging) (rka97)
- `a59dfda` updates to input_pipeline and model spec (priyakasimbeg)
- `1c3cb66` add defaults for lm workload (priyakasimbeg)
- `af91b12` refactor eval pipeline and loss fn for lm (priyakasimbeg)
- `6b55adf` refactor evaluation pipeline for lm (priyakasimbeg)
- `210d671` remove temporary flag for hlo dumps (priyakasimbeg)
- `0ad7788` fix in workload target condition check (priyakasimbeg)
- `01921d5` fix in mlp for glu (priyakasimbeg)
- `e420450` Fix OOM error in weighted cross entropy calculation (rka97)
- `3b31ad5` fix issue with checkpointing bool (rka97)
- `bbc114f` increase buffer size (priyakasimbeg)
- `f531b35` Merge branch 'lm_workload_priya' of github.com:mlcommons/algorithmic-… (priyakasimbeg)
- `2b162e8` remove _eval_batch from jax workload (priyakasimbeg)
- `617e1a3` add todo for pytorch _eval_batch cleanup (priyakasimbeg)
- `bebc80a` Merge pull request #891 from mlcommons/lm_workload_priya (rka97)
- `64ea658` add target setting algorithm for fineweb edu lm workload (priyakasimbeg)
- `b38ade0` update step hint for lm workload (priyakasimbeg)
- `65369f2` update target (priyakasimbeg)
- `6171b2d` update eval split sizes for lm workload and target setting point (priyakasimbeg)
- `d7a885c` Porting workload input pipeline to torch (rka97)
- `f111aea` Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (rka97)
- `1f0439a` Fix OOM bug in lm eval (rka97)
- `b11c193` repeat dataset (rka97)
- `42d1d1a` label smoothing default fix (priyakasimbeg)
- `c334c97` finish merge (priyakasimbeg)
- `d95f2bf` Make sure to take the correct number of batches in lm (rka97)
- `7deb070` Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (rka97)
- `0dc16db` Properly handle repetition in LM training and evaluation splits (rka97)
- `7edb702` move eval_batch from shared class to framework specific classes since… (priyakasimbeg)
- `0879e68` finish merge (priyakasimbeg)
- `73e3ea6` Refactor imports and clean up unused code in LM workload and related … (rka97)
- `91988af` pass linter checks (rka97)
- `bb4a380` Refactor loss function in LM workloads to unify label handling and im… (rka97)
- `a58fbd5` Fix init in both models to be the same, add lm model diff test (rka97)
- `b59afa0` Refactor model configuration classes to make them consistent between … (rka97)
- `d35cdde` Add query-key normalization to CausalAttn and Attention classes, incl… (rka97)
- `ffb8163` update target (priyakasimbeg)
- `2cc9dff` Merge branch 'lm_workload' of github.com:mlcommons/algorithmic-effici… (priyakasimbeg)
- `202e5cb` add pytorch nadamw_target_setting (priyakasimbeg)
- `98e491a` docker updates for a100 (priyakasimbeg)
## Files changed

New file (+153 lines), an input pipeline for an LM dataset:

```python
"""Input pipeline for a LM dataset."""

import functools
import os
from typing import Optional

import jax
import tensorflow as tf

from algoperf import data_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE
PAD_ID = tf.constant(-1, dtype=tf.int64)

TFDS_SPLIT_NAME = {'train': 'train', 'eval_train': 'train', 'validation': 'val'}

SEQUENCE_LENGTH = 1024
MAX_CORPUS_CHARS = 1_000_000_000
SHUFFLE_BUFFER_SIZE = 1000
VOCAB_SIZE = 50_257


def batch_with_padding(
    dataset: tf.data.Dataset,
    batch_size,
    padded_shapes=None,
    padding_id=PAD_ID,
):
  """Batches a tf.data.Dataset and pads the final batch if len(dataset) is
  not divisible by the batch size.

  Args:
    dataset: tf.data.Dataset
    batch_size: batch size of the resulting batched dataset
    padded_shapes: shapes of the padded batches
    padding_id: value used to pad elements in the final short batch

  Returns:
    The batched dataset, with the remainder batch padded up to batch_size.
  """
  batched_dataset = dataset.batch(batch_size, drop_remainder=False)

  # tf.data.Dataset.padded_batch pads elements within a batch, so we call it
  # again with batch_size=1 to pad each element of the original batch.
  padded_batched_dataset = batched_dataset.padded_batch(
      1, padded_shapes=padded_shapes, padding_values=padding_id
  )

  # Remove the extra dimension resulting from batch_size=1.
  padded_batched_dataset = padded_batched_dataset.unbatch()

  return padded_batched_dataset
```
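The pad-to-full-batch behavior can be illustrated framework-free. This is a sketch of the semantics only, not the TensorFlow implementation above; `batch_with_padding_sketch` is a hypothetical helper written for this illustration:

```python
PAD_ID = -1

def batch_with_padding_sketch(sequences, batch_size):
  """Pure-Python sketch: batch equal-length sequences and pad the final
  short batch with PAD_ID rows so every batch has exactly batch_size rows."""
  batches = []
  for i in range(0, len(sequences), batch_size):
    batch = list(sequences[i:i + batch_size])
    while len(batch) < batch_size:
      batch.append([PAD_ID] * len(sequences[0]))
    batches.append(batch)
  return batches

batches = batch_with_padding_sketch([[1, 2], [3, 4], [5, 6]], batch_size=2)
# The last batch is padded with a PAD_ID row: [[5, 6], [-1, -1]]
```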
```python
def get_data_iter(
    data_rng: jax.random.PRNGKey,
    split: str,
    data_dir: str,
    batch_size: int,
    num_batches: Optional[int] = None,
):
  """Return an iterator of sharded, padded numpy batches for the split."""
  ds = get_lm_dataset(data_rng, split, data_dir, batch_size, num_batches)

  # Lazily shard (and pad, if needed) each batch across devices.
  it = map(
      functools.partial(
          data_utils.shard_and_maybe_pad_np, global_batch_size=batch_size
      ),
      ds,
  )

  return iter(it)
```
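The slicing in `get_lm_dataset` below implements next-token prediction: a chunk of `SEQUENCE_LENGTH + 1` tokens yields inputs and targets of equal length, with `targets[t]` being the token that follows `inputs[t]`. A minimal sketch of that alignment (using a toy sequence length):

```python
SEQUENCE_LENGTH = 4  # the real pipeline uses 1024

# A chunk of SEQUENCE_LENGTH + 1 tokens yields one training example.
chunk = [10, 11, 12, 13, 14]

inputs = chunk[:SEQUENCE_LENGTH]  # drop the last token
targets = chunk[1:]               # drop the first token

# Each position t trains one next-token prediction: inputs[t] -> targets[t].
assert len(inputs) == len(targets) == SEQUENCE_LENGTH
assert all(t == chunk[i + 1] for i, t in enumerate(targets))
```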
```python
def get_lm_dataset(
    data_rng: jax.random.PRNGKey,
    split: str,
    data_dir: str,
    batch_size: int,
    num_batches: Optional[int] = None,
):
  """Load preprocessed TF dataset."""
  if split not in TFDS_SPLIT_NAME:
    raise NotImplementedError

  shuffle_seed = jax.random.randint(data_rng, (), -(2**31), 2**31 - 1)

  data_dir = os.path.join(data_dir, TFDS_SPLIT_NAME[split])
  tokens_ds = tf.data.Dataset.load(data_dir)

  # Flatten the stored examples into a stream of tokens.
  tokens_ds = tokens_ds.flat_map(tf.data.Dataset.from_tensor_slices)

  # Group the token stream into fixed-length sequences.
  sequences_ds = tokens_ds.batch(SEQUENCE_LENGTH + 1, drop_remainder=True)

  # Build next-token-prediction inputs and targets from each sequence.
  sequences_ds = sequences_ds.map(
      lambda x: {
          'inputs': x['input_ids'][:SEQUENCE_LENGTH],
          'targets': x['input_ids'][1:],
      },
      num_parallel_calls=AUTOTUNE,
  )
  if split == 'train':
    ds = sequences_ds.shuffle(SHUFFLE_BUFFER_SIZE, seed=shuffle_seed)
    ds = ds.batch(batch_size, drop_remainder=False)
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': None,
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  elif split == 'eval_train':
    ds = batch_with_padding(
        sequences_ds,
        batch_size,
        padded_shapes={
            'inputs': (batch_size, None),
            'targets': (batch_size, None),
        },
    )
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  elif split == 'validation':
    ds = batch_with_padding(
        sequences_ds,
        batch_size,
        padded_shapes={
            'inputs': (batch_size, None),
            'targets': (batch_size, None),
        },
    )
    ds = ds.take(num_batches) if num_batches is not None else ds
    ds = ds.repeat()
    ds = ds.map(
        lambda x: {
            'inputs': x['inputs'],
            'targets': x['targets'],
            'weights': tf.where(tf.equal(x['inputs'], PAD_ID), 0.0, 1.0),
        }
    )
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds
```
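The `weights` field emitted for the eval splits (`0.0` at padded positions, `1.0` elsewhere) is what lets a loss computation ignore padding. A framework-free sketch of how such a mask is typically consumed, assuming a list of per-token losses (the values here are made up for illustration):

```python
PAD_ID = -1

inputs = [5, 8, 2, PAD_ID, PAD_ID]
per_token_loss = [0.4, 0.2, 0.6, 9.9, 9.9]  # junk values at padded positions

# Mirrors the pipeline's tf.where(tf.equal(inputs, PAD_ID), 0.0, 1.0).
weights = [0.0 if tok == PAD_ID else 1.0 for tok in inputs]

# A weighted mean zeroes out and excludes the padded positions entirely:
# (0.4 + 0.2 + 0.6) / 3 = 0.4, regardless of the junk loss values.
masked_loss = sum(l * w for l, w in zip(per_token_loss, weights)) / sum(weights)
```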