Skip to content
2 changes: 2 additions & 0 deletions train/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@ Location of the checkpoints of the trained models plus logs and anything else of

- tr9b-350M-swiglu: `six_ALL_CCFRSTORE/checkpoints/tr9b-350M-swiglu`
- tr9c-1B3-swiglu-pile: `six_ALL_CCFRSTORE/checkpoints/tr9b-1B3-swiglu-pile`

- tr13: Multi-Task Fine-tuning (T0)
197 changes: 197 additions & 0 deletions train/tr13-t0/t0_test.slurm
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#!/bin/bash
#SBATCH --job-name=tr11e-350M-ml-t0
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=40 # number of cores per tasks
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --gres=gpu:1 # number of gpus
#SBATCH -C v100-32g
#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=%x-%j.out # output file name
#SBATCH --account=six@v100

set -x -e

source $six_ALL_CCFRWORK/start-muennighofflmeval
echo "START TIME: $(date)"

variant=main

DATA_OUTPUT_PATH=$six_ALL_CCFRSCRATCH/checkpoints/tr13-test-ml
CHECKPOINT_PATH=$DATA_OUTPUT_PATH/checkpoints/$variant
REPO_PATH=$DATA_OUTPUT_PATH/tr11e-350M-ml-logs
TENSORBOARD_PATH=$REPO_PATH/tensorboard-test/$variant
LOGS_PATH=$REPO_PATH/logs-test/$variant
mkdir -p $LOGS_PATH

MEGATRON_DEEPSPEED_REPO=/gpfsscratch/rech/six/commun/commun/experiments/muennighoff/megdsmtf/thomas2/Megatron-DeepSpeed
cd $MEGATRON_DEEPSPEED_REPO

BIGSCIENCE_REPO=$six_ALL_CCFRWORK/code/bigscience
TRAIN_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/p3_train.txt
VALID_DATA_PATH=$MEGATRON_DEEPSPEED_REPO/p3_validation.txt
TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles

# defining the right environment variables
export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models
export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasets
export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules
export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

# testing for potential faulty nodes
# srun --jobid $SLURM_JOBID bash -c 'python -c "import torch, socket; print(socket.gethostname(), torch.cuda.is_available())"'

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6001

GPUS_PER_NODE=1
NNODES=1

PP_SIZE=1
TP_SIZE=1

MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=4

NLAYERS=2
NHIDDEN=1024
NHEADS=16
SEQ_LEN=256

SAVE_INTERVAL=250

TRAIN_SAMPLES=10 # TODO
LR_DECAY_SAMPLES=10 # TODO
LR_WARMUP_SAMPLES=1 # TODO


OPTIMIZER_ARGS=" \
--optimizer adam \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--adam-eps 1e-8 \
--lr 3.0e-4 \
--min-lr 1e-5 \
--lr-decay-style cosine \
--lr-decay-samples $LR_DECAY_SAMPLES \
--lr-warmup-samples $LR_WARMUP_SAMPLES \
--clip-grad 1.0 \
--weight-decay 1e-1 \
"
# for 20h 1190, for 100h 5990
# --exit-duration-in-mins 1190 \
EXIT_OPTS=" \
--exit-duration-in-mins 5990 \
"

GPT_ARGS=" \
--pp-partition-method 'type:transformer|embedding' \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--num-attention-heads $NHEADS \
--seq-length $SEQ_LEN \
--max-position-embeddings $SEQ_LEN \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--train-samples $TRAIN_SAMPLES \
--tokenizer-type PretrainedFromHF \
--tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \
--init-method-std 0.0048 \
--embed-layernorm \
--fp16 \
--seed 42 \
--position-embedding-type alibi \
--abort-on-unmet-fused-kernel-constraints \
--pad-vocab-size-to 250880 \
$OPTIMIZER_ARGS \
$EXIT_OPTS \
"

OUTPUT_ARGS=" \
--log-interval 1 \
--save-interval $SAVE_INTERVAL \
--eval-interval 1000 \
--eval-iters 1 \
--tensorboard-dir $TENSORBOARD_PATH \
--tensorboard-queue-size 5 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
"

ZERO_STAGE=0 # important: bf16 must use z0! it implements its own zero stage 1 equivalent

config_json="./ds_config.$SLURM_JOBID.json"

# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
cat <<EOT > $config_json
{
"train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
"train_batch_size": $GLOBAL_BATCH_SIZE,
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": $ZERO_STAGE
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1,
"initial_scale_power": 12
},
"steps_per_print": 2000,
"wall_clock_breakdown": false
}
EOT


DEEPSPEED_ARGS=" \
--deepspeed \
--deepspeed_config ${config_json} \
--zero-stage ${ZERO_STAGE} \
--deepspeed-activation-checkpointing \
"

export LAUNCHER="python -u -m torch.distributed.run \
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
--rdzv_backend c10d \
--max_restarts 0 \
--tee 3 \
"

# Data loading option 1:
#DATA_PATH="/gpfswork/rech/six/commun/bigscience-training/p3t0/p3_t0_train"
# --data-path $DATA_PATH \
# --split 100,0,0 \

export CMD=" \
`pwd`/finetune_t0_non_causal_decoder.py \
--tensor-model-parallel-size $TP_SIZE \
--pipeline-model-parallel-size $PP_SIZE \
$GPT_ARGS \
$OUTPUT_ARGS \
--train-weighted-split-paths-path $TRAIN_DATA_PATH \
--valid-weighted-split-paths-path $VALID_DATA_PATH \
--dataloader-type single \
--data-impl mmap \
--distributed-backend nccl \
$DEEPSPEED_ARGS \
"

echo $CMD

# do not remove or the training will hang and nodes will be lost w/o this workaround
export CUDA_LAUNCH_BLOCKING=1

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

clear; srun --jobid $SLURM_JOBID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID $CMD" 2>&1 | tee -a $LOGS_PATH/main_log.txt

echo "END TIME: $(date)"
Loading