We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9e75f1f commit b870fabCopy full SHA for b870fab
dictionary_learning/dictionary.py
@@ -33,7 +33,7 @@ def __init__(
33
activation_shape: tuple[int, ...] | None = None,
34
*,
35
keep_relative_variance: bool = True,
36
- target_rms: float = 1.0,
+ target_rms: float | None = 1.0,
37
):
38
"""
39
Initialize the normalization mixin.
@@ -50,6 +50,8 @@ def __init__(
50
51
super().__init__()
52
self.keep_relative_variance = keep_relative_variance
53
+ if target_rms is None:
54
+ target_rms = 1.0
55
self.register_buffer("target_rms", th.tensor(target_rms))
56
if activation_mean is not None and activation_std is not None:
57
# Type assertion to help linter understand these are tensors
0 commit comments