Skip to content

Commit 30cfeb0

Browse files
committed
Added a fix for how activation normalization works on the crosscoder
1 parent fa57b0b commit 30cfeb0

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

dictionary_learning/dictionary.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -98,13 +98,11 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tenso
9898
if self.has_activation_normalizer:
9999
if not inplace:
100100
x = x.clone()
101-
# Type assertions for linter
102-
assert isinstance(self.activation_mean, th.Tensor)
103-
assert isinstance(self.activation_std, th.Tensor)
101+
assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
104102
x = x - self.activation_mean
105103

106104
if self.keep_relative_variance:
107-
return (x.T * self.activation_global_scale).T
105+
return x * self.activation_global_scale.unsqueeze(0).unsqueeze(-1)
108106
else:
109107
return x / (self.activation_std + 1e-8)
110108
return x
@@ -123,12 +121,10 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Ten
123121
if self.has_activation_normalizer:
124122
if not inplace:
125123
x = x.clone()
126-
# Type assertions for linter
127-
assert isinstance(self.activation_mean, th.Tensor)
128-
assert isinstance(self.activation_std, th.Tensor)
124+
assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
129125

130126
if self.keep_relative_variance:
131-
x = (x.T / (self.activation_global_scale + 1e-8)).T
127+
x = x / (self.activation_global_scale.unsqueeze(0).unsqueeze(-1) + 1e-8)
132128
else:
133129
x = x * (self.activation_std + 1e-8)
134130

0 commit comments

Comments
 (0)