Skip to content

Commit 4a7cbd2

Browse files
Update SGD.step.
* Update src/TorchSharp/Optimizers/SGD.cs — update SGD.step: declare TorchSharp.Scalar explicitly; cache the results of momentum != 0, dampening != 1, and weight_decay != 0; omit unused TorchSharp.Scalar constructions.
1 parent cee4875 commit 4a7cbd2

File tree

1 file changed

+21
-15
lines changed
  • src/TorchSharp/Optimizers

1 file changed

+21
-15
lines changed

src/TorchSharp/Optimizers/SGD.cs

Lines changed: 21 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -137,14 +137,21 @@ public SGD(IEnumerable<ParamGroup> parameters, double lr, double momentum = 0.0,
137137
public override Tensor step(Func<Tensor> closure = null)
138138
{
139139
return _step<ParamGroup>(group => {
140-
140+
#nullable enable
141141
var options = group.Options;
142-
var momentum = options.momentum.Value;
143-
var dampening = options.dampening.Value;
144-
var weight_decay = options.weight_decay.Value;
145-
var nesterov = options.nesterov.Value;
146-
var maximize = options.maximize.Value;
147-
var lr = options.LearningRate.Value;
142+
var momentum = options.momentum!.Value;
143+
var need_momentum = momentum != 0;
144+
using var momentum_scalar = (need_momentum) ? momentum.ToScalar() : null;
145+
var dampening = options.dampening!.Value;
146+
var need_dampening = dampening != 1;
147+
using var dampening_bar_scalar = (need_momentum && need_dampening) ? (1 - dampening).ToScalar() : null;
148+
var weight_decay = options.weight_decay!.Value;
149+
var need_weight_decay = weight_decay != 0;
150+
using var weight_decay_scalar = (need_weight_decay) ? weight_decay.ToScalar() : null;
151+
var nesterov = options.nesterov!.Value;
152+
var maximize = options.maximize!.Value;
153+
var lr = options.LearningRate!.Value;
154+
using var signed_lr_scalar = ((maximize) ? lr : -lr).ToScalar();
148155

149156
foreach (var param in group.Parameters) {
150157

@@ -154,33 +161,32 @@ public override Tensor step(Func<Tensor> closure = null)
154161

155162
if (grad is null) continue;
156163

157-
if (weight_decay != 0) {
158-
grad = grad.add(param, alpha: weight_decay);
159-
}
164+
if (need_weight_decay) grad = grad.add(param, alpha: weight_decay_scalar!);
160165

161-
if (momentum != 0) {
166+
if (need_momentum) {
162167
var buf = state.momentum_buffer;
163168

164169
if (buf is null) {
165170
buf = grad.clone().detach().DetachFromDisposeScope();
166171
state.momentum_buffer = buf;
167172
} else {
168-
buf.mul_(momentum).add_(grad, alpha: (1 - dampening));
173+
buf.mul_(momentum_scalar!);
174+
if (need_dampening) buf.add_(grad, alpha: dampening_bar_scalar!);
169175
}
170176

171177
if (nesterov) {
172-
grad = grad.add(buf, alpha: momentum);
178+
grad = grad.add(buf, alpha: momentum_scalar!);
173179
} else {
174180
grad = buf;
175181
}
176182

177183
state.momentum_buffer = buf;
178184
}
179185

180-
var alpha = maximize ? lr : -lr;
181-
param.add_(grad, alpha: alpha);
186+
param.add_(grad, alpha: signed_lr_scalar);
182187

183188
}
189+
#nullable disable
184190
}, closure);
185191
}
186192

0 commit comments

Comments (0)