You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Enhance model handling and logging in training process
This commit introduces several improvements to the training module, including:
- Addition of the `get_model` function to streamline model retrieval from the trainer, ensuring consistent handling of compiled models.
- Updates to the `log_stats` and `run_validation` functions to incorporate an optional `epoch_idx_per_step` parameter for enhanced logging capabilities.
- Refactoring of the `save_model` function to utilize the new `get_model` function, simplifying the code and improving clarity.
- Documentation enhancements in the `trainSAE` function, including detailed parameter descriptions and assertions for validation data.
These changes aim to improve the clarity and maintainability of the training code while enhancing logging functionality for better tracking of training progress.
trainer_config: Configuration dictionary for the trainer
219
+
use_wandb: Whether to use Weights & Biases logging (default: False)
220
+
wandb_entity: W&B entity name (default: "")
221
+
wandb_project: W&B project name (default: "")
222
+
steps: Maximum number of training steps (default: None)
223
+
save_steps: Frequency of model checkpointing (default: None)
224
+
save_dir: Directory to save checkpoints and config (default: None)
225
+
log_steps: Frequency of logging statistics (default: None)
226
+
activations_split_by_head: Whether activations are split by attention head (default: False)
227
+
validate_every_n_steps: Frequency of validation evaluation (default: None)
228
+
validation_data: Validation data iterator/dataloader (default: None)
229
+
transcoder: Whether training a transcoder model (default: False)
230
+
run_cfg: Additional run configuration (default: {})
231
+
end_of_step_logging_fn: Custom logging function called at end of each step (default: None)
232
+
save_last_eval: Whether to save evaluation results at end of training (default: True)
233
+
start_of_training_eval: Whether to run evaluation before training starts (default: False)
234
+
dtype: Training data type (default: torch.float32)
235
+
run_wandb_finish: Whether to call wandb.finish() at end of training (default: True)
236
+
epoch_idx_per_step: Optional mapping of training steps to epoch indices (default: None). Mainly used for logging when the dataset is pre-shuffled and contains multiple epochs.
237
+
238
+
Returns:
239
+
Trained model
240
+
241
+
Raises:
242
+
AssertionError: If validation_data is None but validate_every_n_steps is specified
0 commit comments