@@ -594,112 +594,118 @@ def forecast(
594
594
Any
595
595
The forecasted state.
596
596
"""
597
- self .model .eval ()
597
+ # NOTE we are not using decorator of the top level function as we anticipate lazy torch load
598
+ with torch .inference_mode ():
599
+ self .model .eval ()
598
600
599
- torch .set_grad_enabled (False )
601
+ # Create pytorch input tensor
602
+ input_tensor_torch = torch .from_numpy (np .swapaxes (input_tensor_numpy , - 2 , - 1 )[np .newaxis , ...]).to (
603
+ self .device
604
+ )
600
605
601
- # Create pytorch input tensor
602
- input_tensor_torch = torch .from_numpy (np .swapaxes (input_tensor_numpy , - 2 , - 1 )[np .newaxis , ...]).to (self .device )
606
+ lead_time = to_timedelta (lead_time )
603
607
604
- lead_time = to_timedelta (lead_time )
605
-
606
- new_state = input_state .copy () # We should not modify the input state
607
- new_state ["fields" ] = dict ()
608
- new_state ["step" ] = to_timedelta (0 )
609
-
610
- start = input_state ["date" ]
611
-
612
- # The variable `check` is used to keep track of which variables have been updated
613
- # In the input tensor. `reset` is used to reset `check` to False except
614
- # when the values are of the constant in time variables
615
-
616
- reset = np .full ((input_tensor_torch .shape [- 1 ],), False )
617
- variable_to_input_tensor_index = self .checkpoint .variable_to_input_tensor_index
618
- typed_variables = self .checkpoint .typed_variables
619
- for variable , i in variable_to_input_tensor_index .items ():
620
- if typed_variables [variable ].is_constant_in_time :
621
- reset [i ] = True
622
-
623
- check = reset .copy ()
624
-
625
- if self .verbosity > 0 :
626
- self ._print_input_tensor ("First input tensor" , input_tensor_torch )
627
-
628
- for s , (step , date , next_date , is_last_step ) in enumerate (self .forecast_stepper (start , lead_time )):
629
- title = f"Forecasting step { step } ({ date } )"
630
-
631
- new_state ["date" ] = date
632
- new_state ["previous_step" ] = new_state .get ("step" )
633
- new_state ["step" ] = step
634
-
635
- if self .trace :
636
- self .trace .write_input_tensor (
637
- date , s , input_tensor_torch .cpu ().numpy (), variable_to_input_tensor_index , self .checkpoint .timestep
638
- )
608
+ new_state = input_state .copy () # We should not modify the input state
609
+ new_state ["fields" ] = dict ()
610
+ new_state ["step" ] = to_timedelta (0 )
639
611
640
- # Predict next state of atmosphere
641
- with (
642
- torch .autocast (device_type = self .device .type , dtype = self .autocast ),
643
- ProfilingLabel ("Predict step" , self .use_profiler ),
644
- Timer (title ),
645
- ):
646
- y_pred = self .predict_step (self .model , input_tensor_torch , fcstep = s , step = step , date = date )
612
+ start = input_state ["date" ]
647
613
648
- output = torch .squeeze (y_pred , dim = (0 , 1 )) # shape: (values, variables)
614
+ # The variable `check` is used to keep track of which variables have been updated
615
+ # In the input tensor. `reset` is used to reset `check` to False except
616
+ # when the values are of the constant in time variables
649
617
650
- # Update state
651
- with ProfilingLabel ("Updating state (CPU)" , self .use_profiler ):
652
- for i in range (output .shape [1 ]):
653
- new_state ["fields" ][self .checkpoint .output_tensor_index_to_variable [i ]] = output [:, i ]
618
+ reset = np .full ((input_tensor_torch .shape [- 1 ],), False )
619
+ variable_to_input_tensor_index = self .checkpoint .variable_to_input_tensor_index
620
+ typed_variables = self .checkpoint .typed_variables
621
+ for variable , i in variable_to_input_tensor_index .items ():
622
+ if typed_variables [variable ].is_constant_in_time :
623
+ reset [i ] = True
654
624
655
- if (s == 0 and self .verbosity > 0 ) or self .verbosity > 1 :
656
- self ._print_output_tensor ("Output tensor" , output .cpu ().numpy ())
625
+ check = reset .copy ()
657
626
658
- if self .trace :
659
- self .trace .write_output_tensor (
660
- date ,
661
- s ,
662
- output .cpu ().numpy (),
663
- self .checkpoint .output_tensor_index_to_variable ,
664
- self .checkpoint .timestep ,
665
- )
627
+ if self .verbosity > 0 :
628
+ self ._print_input_tensor ("First input tensor" , input_tensor_torch )
666
629
667
- yield new_state
630
+ for s , (step , date , next_date , is_last_step ) in enumerate (self .forecast_stepper (start , lead_time )):
631
+ title = f"Forecasting step { step } ({ date } )"
668
632
669
- # No need to prepare next input tensor if we are at the last step
670
- if is_last_step :
671
- break
633
+ new_state [ "date" ] = date
634
+ new_state [ "previous_step" ] = new_state . get ( "step" )
635
+ new_state [ "step" ] = step
672
636
673
- # Update tensor for next iteration
674
- with ProfilingLabel ("Update tensor for next step" , self .use_profiler ):
675
- check [:] = reset
676
637
if self .trace :
677
- self .trace .reset_sources (reset , self .checkpoint .variable_to_input_tensor_index )
678
-
679
- input_tensor_torch = self .copy_prognostic_fields_to_input_tensor (input_tensor_torch , y_pred , check )
680
-
681
- del y_pred # Recover memory
638
+ self .trace .write_input_tensor (
639
+ date ,
640
+ s ,
641
+ input_tensor_torch .cpu ().numpy (),
642
+ variable_to_input_tensor_index ,
643
+ self .checkpoint .timestep ,
644
+ )
645
+
646
+ # Predict next state of atmosphere
647
+ with (
648
+ torch .autocast (device_type = self .device .type , dtype = self .autocast ),
649
+ ProfilingLabel ("Predict step" , self .use_profiler ),
650
+ Timer (title ),
651
+ ):
652
+ y_pred = self .predict_step (self .model , input_tensor_torch , fcstep = s , step = step , date = date )
653
+
654
+ output = torch .squeeze (y_pred , dim = (0 , 1 )) # shape: (values, variables)
655
+
656
+ # Update state
657
+ with ProfilingLabel ("Updating state (CPU)" , self .use_profiler ):
658
+ for i in range (output .shape [1 ]):
659
+ new_state ["fields" ][self .checkpoint .output_tensor_index_to_variable [i ]] = output [:, i ]
660
+
661
+ if (s == 0 and self .verbosity > 0 ) or self .verbosity > 1 :
662
+ self ._print_output_tensor ("Output tensor" , output .cpu ().numpy ())
682
663
683
- input_tensor_torch = self .add_dynamic_forcings_to_input_tensor (
684
- input_tensor_torch , new_state , next_date , check
685
- )
686
- input_tensor_torch = self .add_boundary_forcings_to_input_tensor (
687
- input_tensor_torch , new_state , next_date , check
688
- )
689
-
690
- if not check .all ():
691
- # Not all variables have been updated
692
- missing = []
693
- variable_to_input_tensor_index = self .checkpoint .variable_to_input_tensor_index
694
- mapping = {v : k for k , v in variable_to_input_tensor_index .items ()}
695
- for i in range (check .shape [- 1 ]):
696
- if not check [i ]:
697
- missing .append (mapping [i ])
698
-
699
- raise ValueError (f"Missing variables in input tensor: { sorted (missing )} " )
700
-
701
- if (s == 0 and self .verbosity > 0 ) or self .verbosity > 1 :
702
- self ._print_input_tensor ("Next input tensor" , input_tensor_torch )
664
+ if self .trace :
665
+ self .trace .write_output_tensor (
666
+ date ,
667
+ s ,
668
+ output .cpu ().numpy (),
669
+ self .checkpoint .output_tensor_index_to_variable ,
670
+ self .checkpoint .timestep ,
671
+ )
672
+
673
+ yield new_state
674
+
675
+ # No need to prepare next input tensor if we are at the last step
676
+ if is_last_step :
677
+ break
678
+
679
+ # Update tensor for next iteration
680
+ with ProfilingLabel ("Update tensor for next step" , self .use_profiler ):
681
+ check [:] = reset
682
+ if self .trace :
683
+ self .trace .reset_sources (reset , self .checkpoint .variable_to_input_tensor_index )
684
+
685
+ input_tensor_torch = self .copy_prognostic_fields_to_input_tensor (input_tensor_torch , y_pred , check )
686
+
687
+ del y_pred # Recover memory
688
+
689
+ input_tensor_torch = self .add_dynamic_forcings_to_input_tensor (
690
+ input_tensor_torch , new_state , next_date , check
691
+ )
692
+ input_tensor_torch = self .add_boundary_forcings_to_input_tensor (
693
+ input_tensor_torch , new_state , next_date , check
694
+ )
695
+
696
+ if not check .all ():
697
+ # Not all variables have been updated
698
+ missing = []
699
+ variable_to_input_tensor_index = self .checkpoint .variable_to_input_tensor_index
700
+ mapping = {v : k for k , v in variable_to_input_tensor_index .items ()}
701
+ for i in range (check .shape [- 1 ]):
702
+ if not check [i ]:
703
+ missing .append (mapping [i ])
704
+
705
+ raise ValueError (f"Missing variables in input tensor: { sorted (missing )} " )
706
+
707
+ if (s == 0 and self .verbosity > 0 ) or self .verbosity > 1 :
708
+ self ._print_input_tensor ("Next input tensor" , input_tensor_torch )
703
709
704
710
def copy_prognostic_fields_to_input_tensor (
705
711
self , input_tensor_torch : torch .Tensor , y_pred : torch .Tensor , check : BoolArray
0 commit comments