17
17
from .utils import set_decoder_norm_to_unit_norm , ActivationNormalizer
18
18
19
19
20
- class NormalizableMixin (ABC ):
20
+ class NormalizableMixin (nn . Module ):
21
21
"""
22
22
Mixin class providing activation normalization functionality.
23
23
@@ -34,6 +34,7 @@ def __init__(self, activation_normalizer: ActivationNormalizer | None = None):
34
34
activation_normalizer: Optional normalizer for activations. If None,
35
35
normalization is a no-op.
36
36
"""
37
+ super ().__init__ ()
37
38
self .activation_normalizer = activation_normalizer
38
39
if self .activation_normalizer is not None :
39
40
self .activation_normalizer .to (self .device )
@@ -400,7 +401,7 @@ def from_pretrained(
400
401
return autoencoder .to (dtype = dtype , device = device )
401
402
402
403
403
- class BatchTopKSAE (Dictionary , nn . Module , NormalizableMixin ):
404
+ class BatchTopKSAE (NormalizableMixin , Dictionary ):
404
405
"""
405
406
Batch Top-K Sparse Autoencoder implementation.
406
407
@@ -943,7 +944,7 @@ def __repr__(self) -> str:
943
944
return self .name
944
945
945
946
946
- class CrossCoder (Dictionary , nn . Module , NormalizableMixin ):
947
+ class CrossCoder (Dictionary , NormalizableMixin ):
947
948
"""
948
949
A crosscoder sparse autoencoder for multi-layer activation processing.
949
950
0 commit comments