Commit 2c27cc1

Authored by quic-swatia (Swati Allabadi)
[QEff. Finetune]: Correcting num_steps trained as per max_train_step and displaying non scaled loss value on console. (#527)
Signed-off-by: Swati Allabadi <[email protected]>
Co-authored-by: Swati Allabadi <[email protected]>
1 parent 5d381b7 commit 2c27cc1

File tree

1 file changed (+7, -10 lines)


QEfficient/finetune/utils/train_utils.py

Lines changed: 7 additions & 10 deletions
@@ -124,10 +124,9 @@ def train(
 
         if train_config.use_peft and train_config.from_peft_checkpoint:
             intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
+            intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
             if epoch < intermediate_epoch:
                 logger.log_rank_zero(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
-                # to bring the count of train_step in sync with where it left off
-                total_train_steps += len(train_dataloader)
                 continue
 
         logger.log_rank_zero(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
@@ -149,20 +148,18 @@ def train(
 
         num_dummy_samples = 0
         for step, batch in enumerate(train_dataloader):
+            # total_train_steps indicates the cumulative number of training steps completed across all epochs.
+            # When resuming fine-tuning from previously saved checkpoints, total_train_steps indicates the total number of steps trained across the earlier session and the ongoing one.
+            total_train_steps = (epoch) * len(train_dataloader) + step
             # resume training from a particular checkpoint, assuming the dataset is not shuffled
             if train_config.use_peft and train_config.from_peft_checkpoint:
-                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
-                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
                 # to bring the count of train_step in sync with where it left off
                 if epoch == intermediate_epoch and step == 0:
-                    total_train_steps += intermediate_step
                     logger.log_rank_zero(
                         f"Skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for it."
                     )
                 if epoch == intermediate_epoch and step < intermediate_step:
-                    total_train_steps += 1
                     continue
-                total_train_steps += 1
 
             if train_config.max_train_step > 0 and total_train_steps >= train_config.max_train_step:
                 max_steps_reached = True
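
Note on this hunk: total_train_steps is now derived directly from epoch and step (total_train_steps = epoch * len(train_dataloader) + step) instead of being incremented in several places, so the max_train_step cap and checkpoint resumption always agree on the global step count. Below is a minimal, self-contained sketch of that counting logic; the dataloader length, cap, and checkpoint position are made-up values for illustration, not QEfficient APIs.

# Minimal sketch (not from the repository) of deriving the global step from
# epoch and step, so resuming from a checkpoint and max_train_step stay in sync.
num_epochs = 3
steps_per_epoch = 5          # stands in for len(train_dataloader)
max_train_step = 12          # hard cap on total training steps
intermediate_epoch = 1       # epoch recovered from the checkpoint path (illustrative)
intermediate_step = 2        # step recovered from the checkpoint path (illustrative)

executed = []
for epoch in range(num_epochs):
    if epoch < intermediate_epoch:
        continue  # whole epoch already trained in an earlier session
    for step in range(steps_per_epoch):
        # cumulative steps across all epochs, matching the new formula
        total_train_steps = epoch * steps_per_epoch + step
        if epoch == intermediate_epoch and step < intermediate_step:
            continue  # steps already trained in the earlier session
        if max_train_step > 0 and total_train_steps >= max_train_step:
            break  # stop exactly at the configured cap
        executed.append(total_train_steps)

print(executed)  # [7, 8, 9, 10, 11] -> resumes at global step 7, stops before 12
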
@@ -235,12 +232,12 @@ def train(
             else:
                 num_samples_in_cur_update = len(train_dataloader) % train_config.gradient_accumulation_steps
 
-            loss = loss / num_samples_in_cur_update
+            normalized_loss = loss / num_samples_in_cur_update
 
             if train_config.grad_scaler:
-                scaler.scale(loss).backward()  # backward pass
+                scaler.scale(normalized_loss).backward()  # backward pass
             else:
-                loss.backward()  # backward pass
+                normalized_loss.backward()  # backward pass
 
             if is_optimizer_step:
                 if train_config.grad_scaler:

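Note on the last hunk: keeping the division result in a separate normalized_loss lets backward() see the loss scaled for gradient accumulation, while the original unscaled loss tensor stays available for console logging, which is what the commit title refers to. The following is a small illustrative sketch of that pattern with a toy model and made-up hyperparameters; it is not the QEfficient training loop.

# Gradient-accumulation sketch (illustrative only): the loss used for backward()
# is divided by the number of micro-batches in the current update, while the
# value reported to the user stays unscaled, mirroring the loss -> normalized_loss split.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
gradient_accumulation_steps = 4

optimizer.zero_grad()
for step in range(8):
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = criterion(model(x), y)

    # keep gradients comparable to one large batch
    normalized_loss = loss / gradient_accumulation_steps
    normalized_loss.backward()

    # console/metrics see the raw, unscaled loss
    print(f"step {step}: loss={loss.item():.4f}")

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()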