Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions torchopt/optim/func/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,16 @@ def step(
if inplace is None:
inplace = self.inplace

# Step parameter only
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
updates, self.optim_state = self.impl.update(
grads,
self.optim_state,
params=params,
inplace=inplace,
)
return apply_updates(params, updates, inplace=inplace)
with torch.enable_grad():
# Step parameters only
grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
updates, self.optim_state = self.impl.update(
grads,
self.optim_state,
params=params,
inplace=inplace,
)
return apply_updates(params, updates, inplace=inplace)

def state_dict(self) -> OptState:
"""Extract the references of the optimizer states.
Expand Down
30 changes: 15 additions & 15 deletions torchopt/optim/meta/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,32 @@ def step(self, loss: torch.Tensor) -> None: # pylint: disable=too-many-locals
loss (torch.Tensor): The loss that is used to compute the gradients to the network
parameters.
"""
# Step parameter only
for i, (param_container, state) in enumerate(
zip(self.param_containers_groups, self.state_groups),
):
flat_params: TupleOfTensors
flat_params, container_treespec = pytree.tree_flatten_as_tuple(param_container) # type: ignore[arg-type]

if isinstance(state, UninitializedState):
state = self.impl.init(flat_params)
grads = torch.autograd.grad(
loss,
flat_params,
create_graph=True,
allow_unused=True,
)
updates, new_state = self.impl.update(
grads,
state,
params=flat_params,
inplace=False,
)
self.state_groups[i] = new_state
flat_new_params = apply_updates(flat_params, updates, inplace=False)

with torch.enable_grad():
# Step parameters only
grads = torch.autograd.grad(loss, flat_params, create_graph=True, allow_unused=True)
updates, new_state = self.impl.update(
grads,
state,
params=flat_params,
inplace=False,
)
flat_new_params = apply_updates(flat_params, updates, inplace=False)

new_params: ModuleTensorContainers = pytree.tree_unflatten( # type: ignore[assignment]
container_treespec,
flat_new_params,
)

self.state_groups[i] = new_state
for container, new_param in zip(param_container, new_params):
container.update(new_param)

Expand Down