Commit 7b4b216

Enhance model handling and logging in training process
This commit introduces several improvements to the training module, including:

- Addition of the `get_model` function to streamline model retrieval from the trainer, ensuring consistent handling of compiled models.
- Updates to the `log_stats` and `run_validation` functions to incorporate an optional `epoch_idx_per_step` parameter for enhanced logging capabilities.
- Refactoring of the `save_model` function to utilize the new `get_model` function, simplifying the code and improving clarity.
- Documentation enhancements in the `trainSAE` function, including detailed parameter descriptions and assertions for validation data.

These changes aim to improve the clarity and maintainability of the training code while enhancing logging functionality for better tracking of training progress.
1 parent 4ccaaec commit 7b4b216
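
As context for the `get_model` change: `torch.compile` wraps an `nn.Module` in an `OptimizedModule` and keeps the original module on `_orig_mod`, and the wrapper's `state_dict` keys are prefixed with `_orig_mod.`. The sketch below illustrates what the helper unwraps; `TinyAE` is a made-up stand-in, not a class from this repo.

```python
# Illustration only: TinyAE is a hypothetical stand-in, not part of this repo.
import torch
import torch.nn as nn


class TinyAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)

    def forward(self, x):
        return self.decoder(self.encoder(x))


compiled = torch.compile(TinyAE())     # returns an OptimizedModule wrapper
assert hasattr(compiled, "_orig_mod")  # the uncompiled module is still reachable

# Saving from the wrapper would prefix every key with "_orig_mod.";
# unwrapping first keeps checkpoints loadable by the plain module.
plain = compiled._orig_mod
print(list(plain.state_dict())[:2])    # e.g. ['encoder.weight', 'encoder.bias']
```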

File tree

3 files changed: +55 −20 lines


dictionary_learning/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder, CrossCoder
+from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder, CrossCoder, BatchTopKSAE, BatchTopKCrossCoder
 from .buffer import ActivationBuffer

dictionary_learning/trainers/batch_top_k.py

Lines changed: 1 addition & 1 deletion
@@ -8,8 +8,8 @@
     get_lr_schedule,
     set_decoder_norm_to_unit_norm,
     remove_gradient_parallel_to_decoder_directions,
-    ActivationNormalizer,
 )
+from ..dictionary import ActivationNormalizer
 
 
 class BatchTopKTrainer(SAETrainer):

dictionary_learning/training.py

Lines changed: 53 additions & 18 deletions
@@ -9,6 +9,7 @@
 from tqdm import tqdm
 from warnings import warn
 import wandb
+from typing import List, Optional
 
 from .trainers.crosscoder import CrossCoderTrainer, BatchTopKCrossCoderTrainer
 
@@ -67,6 +68,14 @@ def get_stats(
     out["frac_variance_explained"] = frac_variance_explained.item()
     return out
 
+def get_model(trainer):
+    if hasattr(trainer, "ae"):
+        model = trainer.ae
+    else:
+        model = trainer.model
+    if hasattr(model, "_orig_mod"):  # Check if model is compiled
+        model = model._orig_mod
+    return model
 
 def log_stats(
     trainer,
@@ -76,6 +85,7 @@ def log_stats(
     transcoder: bool,
     stage: str = "train",
     use_threshold: bool = True,
+    epoch_idx_per_step: Optional[List[int]] = None,
 ):
     with th.no_grad():
         log = {}
@@ -96,15 +106,15 @@
         for name, value in trainer_log.items():
             log[f"{stage}/{name}"] = value
 
-        wandb.log(log, step=step)
-
+        wandb.log(log, step=step, epoch=epoch_idx_per_step[step] if epoch_idx_per_step is not None else None)
 
 @th.no_grad()
 def run_validation(
     trainer,
     validation_data,
     step: int = None,
     dtype: th.dtype = th.float32,
+    epoch_idx_per_step: Optional[List[int]] = None,
 ):
     l0 = []
     frac_variance_explained = []
@@ -167,24 +177,15 @@
     ).mean()
     if step is not None:
         log["step"] = step
-    wandb.log(log, step=step)
+    wandb.log(log, step=step, epoch=epoch_idx_per_step[step] if epoch_idx_per_step is not None else None)
 
     return log
 
 
 def save_model(trainer, checkpoint_name, save_dir):
     os.makedirs(save_dir, exist_ok=True)
-    # Handle the case where the model might be compiled
-    if hasattr(trainer, "ae"):
-        model = trainer.ae
-        if hasattr(model, "_orig_mod"):  # Check if model is compiled
-            model = model._orig_mod
-        th.save(model.state_dict(), os.path.join(save_dir, checkpoint_name))
-    else:
-        model = trainer.model
-        if hasattr(model, "_orig_mod"):  # Check if model is compiled
-            model = model._orig_mod
-        th.save(model.state_dict(), os.path.join(save_dir, checkpoint_name))
+    model = get_model(trainer)
+    th.save(model.state_dict(), os.path.join(save_dir, checkpoint_name))
 
 
 def trainSAE(
@@ -206,9 +207,39 @@ def trainSAE(
     save_last_eval=True,
     start_of_training_eval=False,
     dtype=th.float32,
+    run_wandb_finish=True,
+    epoch_idx_per_step: Optional[List[int]] = None,
 ):
     """
     Train SAE using the given trainer
+
+    Args:
+        data: Training data iterator/dataloader
+        trainer_config: Configuration dictionary for the trainer
+        use_wandb: Whether to use Weights & Biases logging (default: False)
+        wandb_entity: W&B entity name (default: "")
+        wandb_project: W&B project name (default: "")
+        steps: Maximum number of training steps (default: None)
+        save_steps: Frequency of model checkpointing (default: None)
+        save_dir: Directory to save checkpoints and config (default: None)
+        log_steps: Frequency of logging statistics (default: None)
+        activations_split_by_head: Whether activations are split by attention head (default: False)
+        validate_every_n_steps: Frequency of validation evaluation (default: None)
+        validation_data: Validation data iterator/dataloader (default: None)
+        transcoder: Whether training a transcoder model (default: False)
+        run_cfg: Additional run configuration (default: {})
+        end_of_step_logging_fn: Custom logging function called at end of each step (default: None)
+        save_last_eval: Whether to save evaluation results at end of training (default: True)
+        start_of_training_eval: Whether to run evaluation before training starts (default: False)
+        dtype: Training data type (default: torch.float32)
+        run_wandb_finish: Whether to call wandb.finish() at end of training (default: True)
+        epoch_idx_per_step: Optional mapping of training steps to epoch indices (default: None). Mainly used for logging when the dataset is pre-shuffled and contains multiple epochs.
+
+    Returns:
+        Trained model
+
+    Raises:
+        AssertionError: If validation_data is None but validate_every_n_steps is specified
     """
     assert not (
         validation_data is None and validate_every_n_steps is not None
@@ -229,7 +260,7 @@ def trainSAE(
 
     trainer.model.to(dtype)
 
-    # make save dir, export config
+    # make save dir, export config
     if save_dir is not None:
         os.makedirs(save_dir, exist_ok=True)
         # save config
@@ -257,6 +288,7 @@
                 activations_split_by_head,
                 transcoder,
                 use_threshold=False,
+                epoch_idx_per_step=epoch_idx_per_step,
             )
             if isinstance(trainer, BatchTopKCrossCoderTrainer):
                 log_stats(
@@ -267,6 +299,7 @@
                     transcoder,
                     use_threshold=True,
                     stage="trainthres",
+                    epoch_idx_per_step=epoch_idx_per_step,
                 )
 
         # saving
@@ -284,7 +317,7 @@
             and (start_of_training_eval or step > 0)
         ):
             print(f"Validating at step {step}")
-            logs = run_validation(trainer, validation_data, step=step, dtype=dtype)
+            logs = run_validation(trainer, validation_data, step=step, dtype=dtype, epoch_idx_per_step=epoch_idx_per_step)
             try:
                 os.makedirs(save_dir, exist_ok=True)
                 th.save(logs, os.path.join(save_dir, f"eval_logs_{step}.pt"))
@@ -295,7 +328,7 @@
             end_of_step_logging_fn(trainer, step)
     try:
         last_eval_logs = run_validation(
-            trainer, validation_data, step=step, dtype=dtype
+            trainer, validation_data, step=step, dtype=dtype, epoch_idx_per_step=epoch_idx_per_step
        )
         if save_last_eval:
             os.makedirs(save_dir, exist_ok=True)
@@ -307,5 +340,7 @@
     if save_dir is not None:
         save_model(trainer, f"model_final.pt", save_dir)
 
-    if use_wandb:
+    if use_wandb and run_wandb_finish:
         wandb.finish()
+
+    return get_model(trainer)
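
The new `epoch_idx_per_step` argument is simply a list mapping training-step index to epoch index, so callers with a pre-shuffled, multi-epoch activation stream can build it up front. Below is a hedged usage sketch: the step counts are illustrative, and `train_activations`, `val_activations`, and the contents of `trainer_config` are placeholders, not values prescribed by this commit.

```python
from dictionary_learning.training import trainSAE

# Illustrative numbers only: a pre-shuffled dataset covering 3 epochs of 1,000 steps each.
steps_per_epoch = 1_000
num_epochs = 3
total_steps = steps_per_epoch * num_epochs

# step i -> epoch index, consumed by log_stats / run_validation when logging to wandb
epoch_idx_per_step = [i // steps_per_epoch for i in range(total_steps)]

model = trainSAE(
    data=train_activations,           # placeholder: your training activation iterator
    trainer_config=trainer_config,    # placeholder: your trainer configuration dict
    steps=total_steps,
    log_steps=50,                     # illustrative logging frequency
    validation_data=val_activations,  # placeholder: your validation activation iterator
    validate_every_n_steps=steps_per_epoch,
    use_wandb=True,
    run_wandb_finish=False,           # keep the wandb run open for any follow-up logging
    epoch_idx_per_step=epoch_idx_per_step,
)
# trainSAE now returns the trained, unwrapped model via get_model(trainer).
```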
