-
Notifications
You must be signed in to change notification settings - Fork 29.9k
Description
Reproduction
The current implementation of `_inner_training_loop` in the `Trainer`, when resuming from a checkpoint, calls `_load_rng_state` after fetching the data batch with `get_batch_samples`. This ordering appears to be designed to handle the complexities of `skip_first_batches` and multi-worker data loading.
However, this order can break true reproducibility if the data-loading process itself involves random operations (e.g., in-batch transformations, random samplers, or datasets with random augmentations in `__getitem__`).
When `get_batch_samples` is called before `_load_rng_state`, the random operations within the dataloader consume the RNG state from the current execution stream, not the restored one. This leads to two issues:
- The data batch fetched is different from the one in the original, uninterrupted run.
- The subsequent `_load_rng_state` call resets the RNG, but the `training_step` (e.g., with Dropout) then operates on this incorrect data. It also uses an RNG state that is out of sync with the original run's state progression: the restored random draws that originally generated the batch are now consumed by the training step instead.
Conversely, if `_load_rng_state` is called before `get_batch_samples`, the entire sequence of random events (data loading + model training) is reproduced exactly, as demonstrated by the experiment below.
This suggests a potential conflict between the current implementation's robustness for `skip_first_batches` and its ability to guarantee bit-for-bit reproducibility in all data-loading scenarios.
Reproducible code
The following script sets up a controlled experiment to demonstrate the issue. It simulates a training run that is interrupted and then resumed, and compares two hypotheses for the recovery order:
- Hypothesis A: `load_rng_state()` -> `get_batch_samples()`
- Hypothesis B: `get_batch_samples()` -> `load_rng_state()` (this mimics the current `Trainer` logic)

The experiment clearly shows that only Hypothesis A reproduces the "golden standard" output from the uninterrupted run.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, Dataset
import os
import random
import numpy as np
import shutil
# --- Core Components: Setup for the experiment ---
def set_seed(seed):
    """Seed every RNG the experiment touches and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG, torch's CPU RNG and,
    when a GPU is present, the RNG of every CUDA device. cuDNN is switched to
    deterministic mode with autotuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Autotuning can pick different kernels between runs; disable it.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def save_rng_state(path):
    """Snapshot all RNG states into ``<path>/rng_state.pth``.

    The directory is created if it does not exist. The saved dict holds the
    torch CPU state, the current CUDA device state (``None`` when CUDA is
    unavailable), NumPy's global state and Python's ``random`` state.

    Args:
        path: Directory to write the checkpoint file into.
    """
    os.makedirs(path, exist_ok=True)
    if torch.cuda.is_available():
        cuda_state = torch.cuda.get_rng_state()
    else:
        cuda_state = None
    snapshot = {
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': cuda_state,
        'numpy_rng_state': np.random.get_state(),
        'python_rng_state': random.getstate(),
    }
    destination = os.path.join(path, 'rng_state.pth')
    torch.save(snapshot, destination)
def load_rng_state(path):
    """Restore the RNG snapshot stored in ``<path>/rng_state.pth``.

    Restores torch's CPU RNG, NumPy's global RNG and Python's ``random`` state.
    It also restores the current CUDA device's RNG, but only when CUDA is
    available and a CUDA state was actually saved.

    Args:
        path: Directory containing ``rng_state.pth``.
    """
    # weights_only=False: the NumPy/Python states are pickled non-tensor
    # objects that restricted (weights_only=True) loading may reject. This
    # unpickles arbitrary objects, so only load checkpoints you created.
    states = torch.load(os.path.join(path, 'rng_state.pth'), weights_only=False)
    torch.set_rng_state(states['torch_rng_state'])
    cuda_state = states['cuda_rng_state']
    # Compare against None explicitly. The saved CUDA state is a
    # multi-element ByteTensor, and truth-testing it raises "Boolean value of
    # Tensor with more than one value is ambiguous".
    if torch.cuda.is_available() and cuda_state is not None:
        torch.cuda.set_rng_state(cuda_state)
    np.random.set_state(states['numpy_rng_state'])
    random.setstate(states['python_rng_state'])
class RandomTransformDataset(Dataset):
    """Wrap a map-style dataset and jitter each sample with tiny uniform noise.

    Every ``__getitem__`` call draws from torch's global RNG, so fetching a
    batch advances the same RNG stream that the model's forward pass uses.

    Args:
        underlying_dataset: Dataset whose items are ``(data, label)`` pairs.
    """

    def __init__(self, underlying_dataset):
        self.underlying_dataset = underlying_dataset

    def __len__(self):
        return len(self.underlying_dataset)

    def __getitem__(self, idx):
        sample, target = self.underlying_dataset[idx]
        # Uniform noise in [0, 0.001): negligible numerically, but it consumes
        # RNG draws, which is exactly what the experiment needs.
        jitter = torch.rand(sample.shape) * 0.001
        return sample + jitter, target
class SimpleModel(nn.Module):
    """Tiny 10 -> 10 network: a Linear layer followed by Dropout(p=0.5).

    In ``train()`` mode the dropout mask is sampled from torch's RNG, which
    makes each forward pass itself random.
    """

    def __init__(self):
        super().__init__()
        # Creation order matters: parameter initialisation draws from the RNG.
        self.linear = nn.Linear(10, 10)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        projected = self.linear(x)
        return self.dropout(projected)
# --- Experiment Execution ---
# Experiment parameters
SEED = 42
CKPT_DIR = "./experiment_ckpt"
INTERRUPT_STEP = 3  # 1-based step after which the RNG checkpoint is written
STEP_TO_VERIFY = INTERRUPT_STEP + 1  # first step executed after "resuming"

# Base data shared by the control and experimental runs. It is drawn before any
# call to set_seed, so its values vary between invocations of the script, but
# every run within one invocation sees the same tensors.
base_dataset = TensorDataset(torch.randn(100, 10), torch.randn(100, 10))

print("=" * 60)
print(" REPRODUCIBILITY EXPERIMENT ")
print("=" * 60)
print("Goal: Determine the correct order of `load_rng_state` and `get_batch_samples`")
print(f" to reproduce training from a checkpoint at step {INTERRUPT_STEP}.")
print(f"Verification will happen at step {STEP_TO_VERIFY}.\n")

# 1. CONTROL GROUP: The uninterrupted run
print("--- [1] Running the Control Group (Golden Standard) ---")
set_seed(SEED)
model = SimpleModel()
model.train()
control_dataloader = DataLoader(
    RandomTransformDataset(base_dataset),
    batch_size=10,
    shuffle=True,
    num_workers=0,
)
control_iterator = iter(control_dataloader)
golden_output = None
# Steps are numbered from 1 so they line up with INTERRUPT_STEP/STEP_TO_VERIFY.
for step_no in range(1, STEP_TO_VERIFY + 1):
    inputs, _ = next(control_iterator)
    output = model(inputs)
    if step_no == INTERRUPT_STEP:
        # RNG is captured after this step's batch fetch and forward pass,
        # i.e. exactly where an interrupted run would pick up again.
        print(f"Step {step_no}: Saving checkpoint...")
        save_rng_state(CKPT_DIR)
    if step_no == STEP_TO_VERIFY:
        golden_output = output.detach().clone()
        print(f"Step {step_no}: Storing golden output. Sum = {golden_output.sum().item():.6f}")

print("\n--- [2] Running Experimental Groups ---")
# Function to run one experimental hypothesis
def run_experiment(hypothesis_name, restore_order):
    """Simulate resuming from the checkpoint and return the first resumed output.

    The model and dataloader are rebuilt from ``SEED`` exactly as in the
    control run. The first ``INTERRUPT_STEP`` batches are then skipped without
    running the model. Finally, the RNG restore and the next batch fetch are
    performed in the requested order, followed by one forward pass.

    Args:
        hypothesis_name: Label used only in the printed header.
        restore_order: ``"load_first"`` (restore RNG, then fetch the batch) or
            ``"get_first"`` (fetch the batch, then restore RNG).

    Returns:
        A detached clone of the model output for the resumed step.

    Raises:
        ValueError: If ``restore_order`` is neither of the values above.
    """
    print(f"\n--- Testing Hypothesis {hypothesis_name} ---")
    print(f"Restore order: {restore_order}")

    set_seed(SEED)
    model_exp = SimpleModel()
    model_exp.train()
    exp_dataloader = DataLoader(
        RandomTransformDataset(base_dataset),
        batch_size=10,
        shuffle=True,
        num_workers=0,
    )
    exp_iterator = iter(exp_dataloader)

    # Fast-forward to the interruption point. No forward passes run here, so
    # the Dropout draws the control run made during these steps are skipped.
    for _ in range(INTERRUPT_STEP):
        next(exp_iterator)

    def restore_rng():
        print("Action: Loading RNG state...")
        load_rng_state(CKPT_DIR)

    def fetch_batch():
        print("Action: Getting batch samples...")
        batch_inputs, _ = next(exp_iterator)
        return batch_inputs

    if restore_order == "load_first":
        restore_rng()
        inputs = fetch_batch()
    elif restore_order == "get_first":
        inputs = fetch_batch()
        restore_rng()
    else:
        raise ValueError("Unknown restore order")

    print("Action: Performing training step...")
    output_exp = model_exp(inputs)
    print(f"Result: Output sum = {output_exp.sum().item():.6f}")
    return output_exp.detach().clone()
# Hypothesis A: load -> get
output_A = run_experiment("A: load_first", "load_first")
# Hypothesis B: get -> load (Mimics Trainer)
output_B = run_experiment("B: get_first", "get_first")

# 4. ANALYSIS & CONCLUSION
print("\n\n" + "=" * 60)
print(" ANALYSIS AND CONCLUSION ")
print("=" * 60)
print(f"Golden Standard Output Sum (from Control Group): {golden_output.sum().item():.6f}\n")

match_A = torch.allclose(golden_output, output_A)
print(f"Hypothesis A ('load_first') Output Sum: {output_A.sum().item():.6f}")
print(f"Does it match the Golden Standard? -> {match_A}")

match_B = torch.allclose(golden_output, output_B)
print(f"\nHypothesis B ('get_first') Output Sum: {output_B.sum().item():.6f}")
print(f"Does it match the Golden Standard? -> {match_B}\n")

print("--- Conclusion based on experimental data ---")
# Only an exclusive match (exactly one hypothesis reproduces the golden
# output) is conclusive; anything else falls through to the default message.
conclusions = {
    (True, False): "✅ The experiment confirms that Hypothesis A ('load_first') is the correct approach for full reproducibility.",
    (False, True): "✅ The experiment confirms that Hypothesis B ('get_first') is correct.",
}
print(conclusions.get((match_A, match_B), "❓ The experiment is inconclusive or failed. Please check the setup."))

# Cleanup
shutil.rmtree(CKPT_DIR)
Expected behavior
To achieve perfect reproducibility, Hypothesis A (`load_rng_state` before `get_batch_samples`) should be the correct approach, as it restores the RNG state before any random operations for the resumed step occur.
The current implementation in the `Trainer` (mimicked by Hypothesis B) fails to reproduce the original run in this scenario.
I understand there are complexities, especially with `skip_first_batches` potentially consuming the RNG state if it is loaded too early. This issue is intended to highlight this trade-off and to start a discussion on whether a more robust solution for perfect reproducibility can be found.
Thank you for your time and for maintaining this incredible library.