Skip to content

Commit 156a2a9

Browse files
Safe update of params in update_grid during training.
1 parent d61b2b5 commit 156a2a9

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

pytorch_forecasting/layers/_kan/_kan_layer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,5 +233,8 @@ def get_grid(num_interval):
233233
return grid
234234

235235
grid = get_grid(num_interval)
236-
self.grid.data = extend_grid(grid, k_extend=self.k)
237-
self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
236+
with torch.no_grad():
237+
new_grid = extend_grid(grid, k_extend=self.k)
238+
self.grid.copy_(new_grid)
239+
new_coef = curve2coef(x_pos, y_eval, self.grid, self.k)
240+
self.coef.copy_(new_coef) # this is a safe inplace copy

0 commit comments

Comments
 (0)