Skip to content

Commit 4ccaaec

Browse files
committed
inheritance bug
1 parent d147f95 commit 4ccaaec

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

dictionary_learning/dictionary.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .utils import set_decoder_norm_to_unit_norm, ActivationNormalizer
1818

1919

20-
class NormalizableMixin(ABC):
20+
class NormalizableMixin(nn.Module):
2121
"""
2222
Mixin class providing activation normalization functionality.
2323
@@ -34,6 +34,7 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
3434
activation_normalizer: Optional normalizer for activations. If None,
3535
normalization is a no-op.
3636
"""
37+
super().__init__()
3738
self.activation_normalizer = activation_normalizer
3839
if self.activation_normalizer is not None:
3940
self.activation_normalizer.to(self.device)
@@ -400,7 +401,7 @@ def from_pretrained(
400401
return autoencoder.to(dtype=dtype, device=device)
401402

402403

403-
class BatchTopKSAE(Dictionary, nn.Module, NormalizableMixin):
404+
class BatchTopKSAE(NormalizableMixin, Dictionary):
404405
"""
405406
Batch Top-K Sparse Autoencoder implementation.
406407
@@ -943,7 +944,7 @@ def __repr__(self) -> str:
943944
return self.name
944945

945946

946-
class CrossCoder(Dictionary, nn.Module, NormalizableMixin):
947+
class CrossCoder(Dictionary, NormalizableMixin):
947948
"""
948949
A crosscoder sparse autoencoder for multi-layer activation processing.
949950

0 commit comments

Comments
 (0)