@@ -124,10 +124,9 @@ def train(
 
         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
+            intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
             if epoch < intermediate_epoch:
                 logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
-                # to bring the count of train_step in sync with where it left off
-                total_train_steps += len(train_dataloader)
                 continue
 
         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
@@ -149,20 +148,18 @@ def train(
 
         num_dummy_samples = 0
         for step, batch in enumerate(train_dataloader):
+            # total_train_steps indicates the cumulative number of training steps completed across all epochs.
+            # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
+            total_train_steps = (epoch) * len(train_dataloader) + step
             # resume training from a particular checkpoint, assuming the dataset is not shuffled
             if train_config.use_peft and train_config.from_peft_checkpoint:
-                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
-                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
-                    total_train_steps += intermediate_step
                     logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
-                    total_train_steps += 1
                     continue
-                total_train_steps += 1
 
             if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
                 max_steps_reached = True
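A minimal sketch of the new bookkeeping, with assumed numbers (1000 batches per epoch, resuming from a checkpoint taken at step 250 of the third epoch): total_train_steps is now derived from epoch and step on every iteration, so the old += adjustments for skipped epochs and steps are no longer needed and the count comes out the same whether or not training was resumed.

    # Sketch only; steps_per_epoch and the resume point are hypothetical values.
    steps_per_epoch = 1000
    intermediate_epoch, intermediate_step = 2, 250  # as parsed from the checkpoint path

    for epoch in range(3):
        if epoch < intermediate_epoch:
            continue  # this epoch was fully trained in the earlier session
        for step in range(steps_per_epoch):
            # cumulative step count across epochs, independent of how many steps were skipped
            total_train_steps = epoch * steps_per_epoch + step
            if epoch == intermediate_epoch and step < intermediate_step:
                continue  # already trained in the earlier session
            # ... forward/backward/optimizer step would run here ...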
@@ -235,12 +232,12 @@ def train(
                 else:
                     num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps
 
-                loss = loss / num_samples_in_cur_update
+                normalized_loss = loss / num_samples_in_cur_update
 
                 if train_config.grad_scaler:
-                    scaler.scale(loss).backward()  # backward pass
+                    scaler.scale(normalized_loss).backward()  # backward pass
                 else:
-                    loss.backward()  # backward pass
+                    normalized_loss.backward()  # backward pass
 
                 if is_optimizer_step:
                     if train_config.grad_scaler:
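The rename to normalized_loss keeps the gradient-accumulation division out of loss itself; a plausible reading (not stated in the diff) is that the raw per-batch loss remains available for metric accumulation while the averaged value drives the backward pass. A minimal, self-contained sketch of that scaling with hypothetical values:

    # Sketch: gradient accumulation over 4 micro-batches with a toy model.
    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    gradient_accumulation_steps = 4

    for step in range(gradient_accumulation_steps):
        x, y = torch.randn(16, 8), torch.randn(16, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        normalized_loss = loss / gradient_accumulation_steps  # average over the update window
        normalized_loss.backward()  # gradients accumulate in .grad across micro-batches

    optimizer.step()       # one optimizer update per accumulation window
    optimizer.zero_grad()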