@@ -137,14 +137,21 @@ public SGD(IEnumerable<ParamGroup> parameters, double lr, double momentum = 0.0,
137137 public override Tensor step ( Func < Tensor > closure = null )
138138 {
139139 return _step < ParamGroup > ( group => {
140-
140+ #nullable enable
141141 var options = group . Options ;
142- var momentum = options . momentum . Value ;
143- var dampening = options . dampening . Value ;
144- var weight_decay = options . weight_decay . Value ;
145- var nesterov = options . nesterov . Value ;
146- var maximize = options . maximize . Value ;
147- var lr = options . LearningRate . Value ;
142+ var momentum = options . momentum ! . Value ;
143+ var need_momentum = momentum != 0 ;
144+ using var momentum_scalar = ( need_momentum ) ? momentum . ToScalar ( ) : null ;
145+ var dampening = options . dampening ! . Value ;
146+ var need_dampening = dampening != 1 ;
147+ using var dampening_bar_scalar = ( need_momentum && need_dampening ) ? ( 1 - dampening ) . ToScalar ( ) : null ;
148+ var weight_decay = options . weight_decay ! . Value ;
149+ var need_weight_decay = weight_decay != 0 ;
150+ using var weight_decay_scalar = ( need_weight_decay ) ? weight_decay . ToScalar ( ) : null ;
151+ var nesterov = options . nesterov ! . Value ;
152+ var maximize = options . maximize ! . Value ;
153+ var lr = options . LearningRate ! . Value ;
154+ using var signed_lr_scalar = ( ( maximize ) ? lr : - lr ) . ToScalar ( ) ;
148155
149156 foreach ( var param in group . Parameters ) {
150157
@@ -154,33 +161,32 @@ public override Tensor step(Func<Tensor> closure = null)
154161
155162 if ( grad is null ) continue ;
156163
157- if ( weight_decay != 0 ) {
158- grad = grad . add ( param , alpha : weight_decay ) ;
159- }
164+ if ( need_weight_decay ) grad = grad . add ( param , alpha : weight_decay_scalar ! ) ;
160165
161- if ( momentum != 0 ) {
166+ if ( need_momentum ) {
162167 var buf = state . momentum_buffer ;
163168
164169 if ( buf is null ) {
165170 buf = grad . clone ( ) . detach ( ) . DetachFromDisposeScope ( ) ;
166171 state . momentum_buffer = buf ;
167172 } else {
168- buf . mul_ ( momentum ) . add_ ( grad , alpha : ( 1 - dampening ) ) ;
173+ buf . mul_ ( momentum_scalar ! ) ;
174+ if ( need_dampening ) buf . add_ ( grad , alpha : dampening_bar_scalar ! ) ;
169175 }
170176
171177 if ( nesterov ) {
172- grad = grad . add ( buf , alpha : momentum ) ;
178+ grad = grad . add ( buf , alpha : momentum_scalar ! ) ;
173179 } else {
174180 grad = buf ;
175181 }
176182
177183 state . momentum_buffer = buf ;
178184 }
179185
180- var alpha = maximize ? lr : - lr ;
181- param . add_ ( grad , alpha : alpha ) ;
186+ param . add_ ( grad , alpha : signed_lr_scalar ) ;
182187
183188 }
189+ #nullable disable
184190 } , closure ) ;
185191 }
186192
0 commit comments