Skip to content

Commit 4b7c495

Browse files
authored
feat: replace no grad with inference mode (#331)
replaces torch.no_grad with torch.inference_mode. 1/ should gain even more speedup, since inference mode also disables version tracking and view-tracking counters, cf https://docs.pytorch.org/docs/2.8/generated/torch.autograd.grad_mode.inference_mode.html . Note this additionally disables forward-mode AD, but I guess this is not being used anywhere. 2/ replaces the torch.set_grad_enabled call with a context manager, so as not to pollute the global state
1 parent 21b300b commit 4b7c495

File tree

1 file changed

+99
-93
lines changed

1 file changed

+99
-93
lines changed

src/anemoi/inference/runner.py

Lines changed: 99 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -594,112 +594,118 @@ def forecast(
594594
Any
595595
The forecasted state.
596596
"""
597-
self.model.eval()
597+
# NOTE we are not using decorator of the top level function as we anticipate lazy torch load
598+
with torch.inference_mode():
599+
self.model.eval()
598600

599-
torch.set_grad_enabled(False)
601+
# Create pytorch input tensor
602+
input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(
603+
self.device
604+
)
600605

601-
# Create pytorch input tensor
602-
input_tensor_torch = torch.from_numpy(np.swapaxes(input_tensor_numpy, -2, -1)[np.newaxis, ...]).to(self.device)
606+
lead_time = to_timedelta(lead_time)
603607

604-
lead_time = to_timedelta(lead_time)
605-
606-
new_state = input_state.copy() # We should not modify the input state
607-
new_state["fields"] = dict()
608-
new_state["step"] = to_timedelta(0)
609-
610-
start = input_state["date"]
611-
612-
# The variable `check` is used to keep track of which variables have been updated
613-
# In the input tensor. `reset` is used to reset `check` to False except
614-
# when the values are of the constant in time variables
615-
616-
reset = np.full((input_tensor_torch.shape[-1],), False)
617-
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
618-
typed_variables = self.checkpoint.typed_variables
619-
for variable, i in variable_to_input_tensor_index.items():
620-
if typed_variables[variable].is_constant_in_time:
621-
reset[i] = True
622-
623-
check = reset.copy()
624-
625-
if self.verbosity > 0:
626-
self._print_input_tensor("First input tensor", input_tensor_torch)
627-
628-
for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
629-
title = f"Forecasting step {step} ({date})"
630-
631-
new_state["date"] = date
632-
new_state["previous_step"] = new_state.get("step")
633-
new_state["step"] = step
634-
635-
if self.trace:
636-
self.trace.write_input_tensor(
637-
date, s, input_tensor_torch.cpu().numpy(), variable_to_input_tensor_index, self.checkpoint.timestep
638-
)
608+
new_state = input_state.copy() # We should not modify the input state
609+
new_state["fields"] = dict()
610+
new_state["step"] = to_timedelta(0)
639611

640-
# Predict next state of atmosphere
641-
with (
642-
torch.autocast(device_type=self.device.type, dtype=self.autocast),
643-
ProfilingLabel("Predict step", self.use_profiler),
644-
Timer(title),
645-
):
646-
y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
612+
start = input_state["date"]
647613

648-
output = torch.squeeze(y_pred, dim=(0, 1)) # shape: (values, variables)
614+
# The variable `check` is used to keep track of which variables have been updated
615+
# In the input tensor. `reset` is used to reset `check` to False except
616+
# when the values are of the constant in time variables
649617

650-
# Update state
651-
with ProfilingLabel("Updating state (CPU)", self.use_profiler):
652-
for i in range(output.shape[1]):
653-
new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
618+
reset = np.full((input_tensor_torch.shape[-1],), False)
619+
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
620+
typed_variables = self.checkpoint.typed_variables
621+
for variable, i in variable_to_input_tensor_index.items():
622+
if typed_variables[variable].is_constant_in_time:
623+
reset[i] = True
654624

655-
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
656-
self._print_output_tensor("Output tensor", output.cpu().numpy())
625+
check = reset.copy()
657626

658-
if self.trace:
659-
self.trace.write_output_tensor(
660-
date,
661-
s,
662-
output.cpu().numpy(),
663-
self.checkpoint.output_tensor_index_to_variable,
664-
self.checkpoint.timestep,
665-
)
627+
if self.verbosity > 0:
628+
self._print_input_tensor("First input tensor", input_tensor_torch)
666629

667-
yield new_state
630+
for s, (step, date, next_date, is_last_step) in enumerate(self.forecast_stepper(start, lead_time)):
631+
title = f"Forecasting step {step} ({date})"
668632

669-
# No need to prepare next input tensor if we are at the last step
670-
if is_last_step:
671-
break
633+
new_state["date"] = date
634+
new_state["previous_step"] = new_state.get("step")
635+
new_state["step"] = step
672636

673-
# Update tensor for next iteration
674-
with ProfilingLabel("Update tensor for next step", self.use_profiler):
675-
check[:] = reset
676637
if self.trace:
677-
self.trace.reset_sources(reset, self.checkpoint.variable_to_input_tensor_index)
678-
679-
input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
680-
681-
del y_pred # Recover memory
638+
self.trace.write_input_tensor(
639+
date,
640+
s,
641+
input_tensor_torch.cpu().numpy(),
642+
variable_to_input_tensor_index,
643+
self.checkpoint.timestep,
644+
)
645+
646+
# Predict next state of atmosphere
647+
with (
648+
torch.autocast(device_type=self.device.type, dtype=self.autocast),
649+
ProfilingLabel("Predict step", self.use_profiler),
650+
Timer(title),
651+
):
652+
y_pred = self.predict_step(self.model, input_tensor_torch, fcstep=s, step=step, date=date)
653+
654+
output = torch.squeeze(y_pred, dim=(0, 1)) # shape: (values, variables)
655+
656+
# Update state
657+
with ProfilingLabel("Updating state (CPU)", self.use_profiler):
658+
for i in range(output.shape[1]):
659+
new_state["fields"][self.checkpoint.output_tensor_index_to_variable[i]] = output[:, i]
660+
661+
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
662+
self._print_output_tensor("Output tensor", output.cpu().numpy())
682663

683-
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
684-
input_tensor_torch, new_state, next_date, check
685-
)
686-
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
687-
input_tensor_torch, new_state, next_date, check
688-
)
689-
690-
if not check.all():
691-
# Not all variables have been updated
692-
missing = []
693-
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
694-
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
695-
for i in range(check.shape[-1]):
696-
if not check[i]:
697-
missing.append(mapping[i])
698-
699-
raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
700-
701-
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
702-
self._print_input_tensor("Next input tensor", input_tensor_torch)
664+
if self.trace:
665+
self.trace.write_output_tensor(
666+
date,
667+
s,
668+
output.cpu().numpy(),
669+
self.checkpoint.output_tensor_index_to_variable,
670+
self.checkpoint.timestep,
671+
)
672+
673+
yield new_state
674+
675+
# No need to prepare next input tensor if we are at the last step
676+
if is_last_step:
677+
break
678+
679+
# Update tensor for next iteration
680+
with ProfilingLabel("Update tensor for next step", self.use_profiler):
681+
check[:] = reset
682+
if self.trace:
683+
self.trace.reset_sources(reset, self.checkpoint.variable_to_input_tensor_index)
684+
685+
input_tensor_torch = self.copy_prognostic_fields_to_input_tensor(input_tensor_torch, y_pred, check)
686+
687+
del y_pred # Recover memory
688+
689+
input_tensor_torch = self.add_dynamic_forcings_to_input_tensor(
690+
input_tensor_torch, new_state, next_date, check
691+
)
692+
input_tensor_torch = self.add_boundary_forcings_to_input_tensor(
693+
input_tensor_torch, new_state, next_date, check
694+
)
695+
696+
if not check.all():
697+
# Not all variables have been updated
698+
missing = []
699+
variable_to_input_tensor_index = self.checkpoint.variable_to_input_tensor_index
700+
mapping = {v: k for k, v in variable_to_input_tensor_index.items()}
701+
for i in range(check.shape[-1]):
702+
if not check[i]:
703+
missing.append(mapping[i])
704+
705+
raise ValueError(f"Missing variables in input tensor: {sorted(missing)}")
706+
707+
if (s == 0 and self.verbosity > 0) or self.verbosity > 1:
708+
self._print_input_tensor("Next input tensor", input_tensor_torch)
703709

704710
def copy_prognostic_fields_to_input_tensor(
705711
self, input_tensor_torch: torch.Tensor, y_pred: torch.Tensor, check: BoolArray

0 commit comments

Comments
 (0)