@@ -98,13 +98,11 @@ def normalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         if self.has_activation_normalizer:
             if not inplace:
                 x = x.clone()
-            # Type assertions for linter
-            assert isinstance(self.activation_mean, th.Tensor)
-            assert isinstance(self.activation_std, th.Tensor)
+            assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
             x = x - self.activation_mean
 
             if self.keep_relative_variance:
-                return (x.T * self.activation_global_scale).T
+                return x * self.activation_global_scale.unsqueeze(0).unsqueeze(-1)
             else:
                 return x / (self.activation_std + 1e-8)
         return x
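
A minimal sketch of why the transpose trick had to go, assuming (not confirmed by the diff) that x is (batch, n_layers, d_model) and activation_global_scale is (n_layers,), i.e. it matches x.shape[1:-1] as the new assert requires. Tensor.T reverses all dimensions, so the old expression lined the scale up against the batch axis rather than the layer axis:

import torch as th

# Assumed shapes, chosen to satisfy the new assert:
# x.shape[1:-1] == activation_global_scale.shape.
x = th.randn(4, 3, 8)   # (batch, n_layers, d_model)
scale = th.rand(3)      # (n_layers,)

# Old: x.T reverses dims to (d_model, n_layers, batch), so the (n_layers,)
# scale broadcasts against the batch axis and fails (or silently scales
# the wrong axis whenever batch == n_layers):
# (x.T * scale).T  # RuntimeError: size 4 (batch) vs 3 (n_layers)

# New: reshape the scale to (1, n_layers, 1) so it broadcasts over batch
# and d_model, scaling each layer's activations uniformly.
y = x * scale.unsqueeze(0).unsqueeze(-1)
assert y.shape == x.shape
assert th.allclose(y[:, 1], x[:, 1] * scale[1])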
@@ -123,12 +121,10 @@ def denormalize_activations(self, x: th.Tensor, inplace: bool = False) -> th.Tensor:
         if self.has_activation_normalizer:
             if not inplace:
                 x = x.clone()
-            # Type assertions for linter
-            assert isinstance(self.activation_mean, th.Tensor)
-            assert isinstance(self.activation_std, th.Tensor)
+            assert x.shape[1:-1] == self.activation_global_scale.shape, "Normalization shape mismatch"
 
             if self.keep_relative_variance:
-                x = (x.T / (self.activation_global_scale + 1e-8)).T
+                x = x / (self.activation_global_scale.unsqueeze(0).unsqueeze(-1) + 1e-8)
             else:
                 x = x * (self.activation_std + 1e-8)
         return x
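
And a hedged round-trip check under the same assumed shapes: with the fixed broadcasting, the keep_relative_variance paths of normalize_activations and denormalize_activations invert each other up to the 1e-8 epsilon. The sketch inlines just the scaling step shown in the hunks, not the mean handling:

import torch as th

# Same assumed shapes as above; mirrors the patched branches directly
# instead of importing the real class.
x = th.randn(4, 3, 8)          # (batch, n_layers, d_model)
scale = th.rand(3) + 0.5       # (n_layers,), kept away from zero

s = scale.unsqueeze(0).unsqueeze(-1)   # (1, n_layers, 1)
y = x * s                              # normalize: keep_relative_variance branch
x_rec = y / (s + 1e-8)                 # denormalize: keep_relative_variance branch
assert th.allclose(x_rec, x, atol=1e-5)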