You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Refactor imports and enhance activation normalization handling
This commit includes the following changes:
- Reformatted the import statements in `__init__.py` for improved readability.
- Increased the sleep duration in `ActivationCache` from 1 to 10 seconds to allow more time for save processes to complete.
- Updated the `BatchTopKSAE`, `CrossCoder`, and `BatchTopKCrossCoder` classes to load the `activation_normalizer` from the state dictionary, ensuring that normalization is applied correctly during model initialization.
- Refined the normalization checks in the `CrossCoderEncoder` and `CrossCoderDecoder` classes to ensure that normalization only occurs if an `activation_normalizer` is present.
- Made minor formatting adjustments in the `training.py` file for better code clarity.
These changes aim to enhance the clarity and maintainability of the code while ensuring proper handling of activation normalization across various components.
trainer_config: Configuration dictionary for the trainer
219
227
use_wandb: Whether to use Weights & Biases logging (default: False)
220
228
wandb_entity: W&B entity name (default: "")
221
229
wandb_project: W&B project name (default: "")
230
+
wandb_group: W&B group name (default: "")
222
231
steps: Maximum number of training steps (default: None)
223
232
save_steps: Frequency of model checkpointing (default: None)
224
233
save_dir: Directory to save checkpoints and config (default: None)
@@ -234,10 +243,10 @@ def trainSAE(
234
243
dtype: Training data type (default: torch.float32)
235
244
run_wandb_finish: Whether to call wandb.finish() at end of training (default: True)
236
245
epoch_idx_per_step: Optional mapping of training steps to epoch indices (default: None). Mainly used for logging when the dataset is pre-shuffled and contains multiple epochs.
237
-
246
+
238
247
Returns:
239
248
Trained model
240
-
249
+
241
250
Raises:
242
251
AssertionError: If validation_data is None but validate_every_n_steps is specified
0 commit comments