Skip to content

Commit b870fab

Browse files
committed
fixed broken retro compatibility
1 parent 9e75f1f commit b870fab

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

dictionary_learning/dictionary.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
activation_shape: tuple[int, ...] | None = None,
3434
*,
3535
keep_relative_variance: bool = True,
36-
target_rms: float = 1.0,
36+
target_rms: float | None = 1.0,
3737
):
3838
"""
3939
Initialize the normalization mixin.
@@ -50,6 +50,8 @@ def __init__(
5050
"""
5151
super().__init__()
5252
self.keep_relative_variance = keep_relative_variance
53+
if target_rms is None:
54+
target_rms = 1.0
5355
self.register_buffer("target_rms", th.tensor(target_rms))
5456
if activation_mean is not None and activation_std is not None:
5557
# Type assertion to help linter understand these are tensors

0 commit comments

Comments
 (0)