
Commit fa57b0b

Refactor ActivationCache and NormalizableMixin for improved clarity and functionality
This commit introduces the following changes:

- Sequence ranges are now always stored when collecting an activation cache with store_tokens=True and shuffle_shards=False; the BOS-token check is no longer required.
- Updated `NormalizableMixin` so that variance calculations are performed along the last dimension, with shape assertions added for clarity. For the crosscoder this means the variance is computed per layer.
- Adjusted the normalization and denormalization methods to maintain tensor shapes correctly during these operations.
- Enhanced logging in `CrossCoderTrainer` to include layer-wise RMS norms for better monitoring of training dynamics.

These modifications improve the clarity and maintainability of the code while ensuring correct behaviour in activation caching and normalization.
1 parent 48debb2 commit fa57b0b

4 files changed (+24, -17 lines)


dictionary_learning/cache.py
Lines changed: 1 addition & 4 deletions

@@ -531,12 +531,9 @@ def collect(
             not shuffle_shards or not store_tokens
         ), "Shuffling shards and storing tokens is not supported yet"

-        # Check if we need to store sequence ranges
-        has_bos_token = model.tokenizer.bos_token is not None
         store_sequence_ranges = (
             store_tokens and
-            not shuffle_shards and
-            not has_bos_token
+            not shuffle_shards
         )

         dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)
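For reference, the new condition keeps sequence ranges whenever tokens are stored and shards are not shuffled, independent of whether the tokenizer has a BOS token. A minimal sketch of that rule (the helper name below is hypothetical, not part of the repository):

def should_store_sequence_ranges(store_tokens: bool, shuffle_shards: bool) -> bool:
    # Mirrors the updated expression in collect(): the BOS-token check is gone.
    return store_tokens and not shuffle_shards

assert should_store_sequence_ranges(store_tokens=True, shuffle_shards=False)
assert not should_store_sequence_ranges(store_tokens=False, shuffle_shards=False)
assert not should_store_sequence_ranges(store_tokens=True, shuffle_shards=True)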

dictionary_learning/dictionary.py
Lines changed: 6 additions & 5 deletions

@@ -45,7 +45,7 @@ def __init__(
             normalization is a no-op.
         activation_shape: Shape of the activation tensor. Required if activation_mean and activation_std are None for proper initialization and registration of the buffers.
         keep_relative_variance: If True, performs global scaling so that the
-            sum of variances is 1 while their relative magnitudes stay unchanged. If false we normalize neuron-wise.
+            sum of variances is 1 while their relative magnitudes stay unchanged. If false we normalize neuron-wise. We normalize the last dimension.
         target_rms: Target RMS for input activation normalization.
         """
         super().__init__()
@@ -69,11 +69,12 @@ def __init__(
             self.register_buffer("activation_std", th.nan * th.ones(activation_shape))

         if self.keep_relative_variance and self.has_activation_normalizer:
-            total_var = (self.activation_std**2).sum()
+            total_var = (self.activation_std**2).sum(dim=-1)
+            assert total_var.shape == self.activation_mean.shape[:-1]
             activation_global_scale = self.target_rms / th.sqrt(total_var + 1e-8)
             self.register_buffer("activation_global_scale", activation_global_scale)
         else:
-            self.register_buffer("activation_global_scale", th.tensor(1.0))
+            self.register_buffer("activation_global_scale", th.ones(activation_shape[:-1]))

     @property
     def has_activation_normalizer(self) -> bool:
@@ -103,7 +104,7 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         x = x - self.activation_mean

         if self.keep_relative_variance:
-            return x * self.activation_global_scale
+            return (x.T * self.activation_global_scale).T
         else:
             return x / (self.activation_std + 1e-8)
         return x
@@ -127,7 +128,7 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         assert isinstance(self.activation_std, th.Tensor)

         if self.keep_relative_variance:
-            x = x / (self.activation_global_scale + 1e-8)
+            x = (x.T / (self.activation_global_scale + 1e-8)).T
         else:
             x = x * (self.activation_std + 1e-8)
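The effect of summing variances over the last dimension is one global scale per leading index, i.e. per layer for a crosscoder, instead of a single scalar. A self-contained sketch of that behaviour, assuming statistics of shape (num_layers, d_model) and activations of shape (batch, num_layers, d_model); broadcasting is done here with an explicit unsqueeze rather than the transpose trick used in the diff:

import torch as th

batch, num_layers, d_model = 8, 2, 16
target_rms = 1.0

# Per-(layer, neuron) statistics, as stored in the mixin's buffers (assumed shapes).
activation_mean = th.randn(num_layers, d_model)
activation_std = th.rand(num_layers, d_model) + 0.5

# Variance summed over the last (neuron) dimension gives one total per layer.
total_var = (activation_std**2).sum(dim=-1)                        # (num_layers,)
assert total_var.shape == activation_mean.shape[:-1]
activation_global_scale = target_rms / th.sqrt(total_var + 1e-8)   # (num_layers,)

x = th.randn(batch, num_layers, d_model)
# Normalize: subtract the mean, then rescale each layer by its own global scale.
x_norm = (x - activation_mean) * activation_global_scale.unsqueeze(-1)
# Denormalize: undo the scaling and add the mean back; shapes are preserved throughout.
x_back = x_norm / (activation_global_scale.unsqueeze(-1) + 1e-8) + activation_mean
assert x_norm.shape == x.shape and th.allclose(x_back, x, atol=1e-4)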

dictionary_learning/trainers/crosscoder.py
Lines changed: 10 additions & 7 deletions

@@ -204,17 +204,20 @@ def loss(
         if not logging:
             return loss
         else:
+            log_dict = {
+                "l2_loss": l2_loss.item(),
+                "mse_loss": mse_loss.item(),
+                "sparsity_loss": l1_loss.item(),
+                "loss": loss.item(),
+                "deads": deads if return_deads else None,
+            }
+            for layer in range(x.shape[1]):
+                log_dict[f"rms_norm_l{layer}"] = th.sqrt((x[:, layer, :].pow(2).sum(-1)).mean()).item()
             return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])(
                 x,
                 x_hat,
                 f,
-                {
-                    "l2_loss": l2_loss.item(),
-                    "mse_loss": mse_loss.item(),
-                    "sparsity_loss": l1_loss.item(),
-                    "loss": loss.item(),
-                    "deads": deads if return_deads else None,
-                },
+                log_dict,
             )

     def update(self, step, activations):
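Each new entry logs the square root of the batch-mean squared L2 norm of that layer's activations, which makes it easy to spot layers whose activation scale drifts during training. A standalone sketch of the same expression, with made-up tensor sizes:

import torch as th

x = th.randn(32, 2, 512)  # (batch, num_layers, d_model); sizes are illustrative only
log_dict = {}
for layer in range(x.shape[1]):
    # sqrt of the batch-mean of each sample's squared L2 norm for this layer.
    log_dict[f"rms_norm_l{layer}"] = th.sqrt(x[:, layer, :].pow(2).sum(-1).mean()).item()
print(log_dict)  # e.g. {"rms_norm_l0": ..., "rms_norm_l1": ...}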

dictionary_learning/training.py
Lines changed: 7 additions & 1 deletion

@@ -89,6 +89,7 @@ def log_stats(
     stage: str = "train",
     use_threshold: bool = True,
     epoch_idx_per_step: Optional[List[int]] = None,
+    num_tokens: int = None,
 ):
     with th.no_grad():
         log = {}
@@ -111,6 +112,8 @@ def log_stats(

         if epoch_idx_per_step is not None:
             log["epoch"] = epoch_idx_per_step[step]
+        if num_tokens is not None:
+            log["num_tokens"] = num_tokens
         wandb.log(log, step=step)


@@ -285,11 +288,12 @@ def trainSAE(
         with open(os.path.join(save_dir, "config.json"), "w") as f:
             json.dump(config, f, indent=4)

+    num_tokens = 0
     for step, act in enumerate(tqdm(data, total=steps)):
         if steps is not None and step >= steps:
             break
         act = act.to(trainer.device).to(dtype)
-
+        num_tokens += act.shape[0]
         # logging
         if log_steps is not None and step % log_steps == 0 and step != 0:
             with th.no_grad():
@@ -301,6 +305,7 @@ def trainSAE(
                     transcoder,
                     use_threshold=False,
                     epoch_idx_per_step=epoch_idx_per_step,
+                    num_tokens=num_tokens,
                 )
             if isinstance(trainer, BatchTopKCrossCoderTrainer) or isinstance(trainer, BatchTopKTrainer):
                 log_stats(
@@ -312,6 +317,7 @@ def trainSAE(
                     use_threshold=True,
                     stage="trainthres",
                     epoch_idx_per_step=epoch_idx_per_step,
+                    num_tokens=num_tokens,
                 )

     # saving