diff --git a/alf/algorithms/algorithm.py b/alf/algorithms/algorithm.py index 294e79f35..6eac112f5 100644 --- a/alf/algorithms/algorithm.py +++ b/alf/algorithms/algorithm.py @@ -1888,8 +1888,9 @@ def _update(self, experience, batch_info, weight): weight (float): weight for this batch. Loss will be multiplied with this weight before calculating gradient. """ - with torch.cuda.amp.autocast(self._config.enable_amp, - dtype=self._config.amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._config.enable_amp, + dtype=self._config.amp_dtype): train_info, loss_info = self._compute_train_info_and_loss_info( experience) @@ -2097,8 +2098,9 @@ def _hybrid_update(self, experience, batch_info, offline_experience, length = alf.nest.get_nest_size(offline_experience, dim=0) if self._RL_train: - with torch.cuda.amp.autocast(self._config.enable_amp, - dtype=self._config.amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._config.enable_amp, + dtype=self._config.amp_dtype): train_info, loss_info = self._compute_train_info_and_loss_info( experience) self._update_priority(loss_info, batch_info, diff --git a/alf/algorithms/muzero_algorithm.py b/alf/algorithms/muzero_algorithm.py index c61af5fe7..d84a92918 100644 --- a/alf/algorithms/muzero_algorithm.py +++ b/alf/algorithms/muzero_algorithm.py @@ -150,7 +150,9 @@ def predict_step(self, time_step: TimeStep, state) -> AlgStep: if self._reward_transformer is not None: time_step = time_step._replace( reward=self._reward_transformer(time_step.reward)) - with torch.cuda.amp.autocast(self._enable_amp, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enable_amp, + dtype=self._amp_dtype): latent = self._repr_learner.predict_step(time_step, state).output return self._mcts.predict_step( time_step._replace(observation=latent), state) diff --git a/alf/algorithms/muzero_representation_learner.py b/alf/algorithms/muzero_representation_learner.py index f4bb97d45..b30900436 100644 --- a/alf/algorithms/muzero_representation_learner.py +++ b/alf/algorithms/muzero_representation_learner.py @@ -305,7 +305,9 @@ def _check_data_transformer(self): transformer) def predict_step(self, time_step: TimeStep, state): - with torch.cuda.amp.autocast(self._enable_amp, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enable_amp, + dtype=self._amp_dtype): return AlgStep(output=self._model.initial_representation( time_step.observation), state=(), @@ -352,8 +354,9 @@ def _hook(grad, name): obs = alf.nest.map_structure(lambda x: x.reshape(-1, *x.shape[2:]), info.target.observation) with torch.no_grad(): - with torch.cuda.amp.autocast(self._enable_amp, - dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enable_amp, + dtype=self._amp_dtype): target_repr = self._model._representation_net(obs)[0] # [B, R+1, ...] target_repr = target_repr.reshape(-1, self._num_unroll_steps + 1, @@ -839,8 +842,9 @@ def _reanalyze1(self, game_overs = convert_device(game_overs) # 1. Reanalyze the first n1 steps to get both the updated value and policy - with torch.cuda.amp.autocast(self._enable_amp, - dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enable_amp, + dtype=self._amp_dtype): latent = self._target_model.initial_representation( exp1.observation) exp1 = exp1._replace(time_step=exp1.time_step._replace( @@ -865,8 +869,9 @@ def _reshape(x): # 2. Calculate the value of the next n2 steps so that n2-step return # can be computed. if not self._full_reanalyze: - with torch.cuda.amp.autocast(self._enable_amp, - dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enable_amp, + dtype=self._amp_dtype): model_output = self._target_model.initial_inference( exp2.observation) values2 = model_output.value.reshape(batch_size, n2) diff --git a/alf/algorithms/rl_algorithm.py b/alf/algorithms/rl_algorithm.py index 773647233..85fa49f18 100644 --- a/alf/algorithms/rl_algorithm.py +++ b/alf/algorithms/rl_algorithm.py @@ -745,8 +745,9 @@ def train_iter(self): @data_distributed_when(lambda algorithm: algorithm.on_policy) def _compute_train_info_and_loss_info_on_policy(self, unroll_length): with record_time("time/unroll"): - with torch.cuda.amp.autocast(self._config.enable_amp, - dtype=self._config.amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._config.enable_amp, + dtype=self._config.amp_dtype): experience = self.unroll(self._config.unroll_length) self.summarize_metrics() @@ -811,8 +812,10 @@ def _unroll_iter_off_policy(self): or self.get_step_metrics()[1].result() < config.num_env_steps)): unrolled = True with torch.set_grad_enabled( - config.unroll_with_grad), torch.cuda.amp.autocast( - config.enable_amp, dtype=self._config.amp_dtype): + config.unroll_with_grad), torch.amp.autocast( + 'cuda', + enabled=config.enable_amp, + dtype=self._config.amp_dtype): with record_time("time/unroll"): self.eval() # The period of performing unroll may not be an integer diff --git a/alf/layers.py b/alf/layers.py index b27803562..316e12054 100644 --- a/alf/layers.py +++ b/alf/layers.py @@ -3739,7 +3739,9 @@ def __init__(self, enabled: bool, net: nn.Module): def forward(self, input): if torch.is_autocast_enabled() and not self._enabled: input = to_float32(input) - with torch.cuda.amp.autocast(self._enabled, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enabled, + dtype=self._amp_dtype): return self._net(input) diff --git a/alf/networks/networks.py b/alf/networks/networks.py index 2f512dde8..f8d50fe98 100644 --- a/alf/networks/networks.py +++ b/alf/networks/networks.py @@ -310,7 +310,9 @@ def __init__(self, enabled: bool, net: Network): def forward(self, input, state): if torch.is_autocast_enabled() and not self._enabled: input = alf.layers.to_float32(input) - with torch.cuda.amp.autocast(self._enabled, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=self._enabled, + dtype=self._amp_dtype): return self._net(input, state) diff --git a/alf/networks/projection_networks.py b/alf/networks/projection_networks.py index 70e20b6f8..d7eee91f1 100644 --- a/alf/networks/projection_networks.py +++ b/alf/networks/projection_networks.py @@ -91,7 +91,9 @@ def forward(self, inputs, state=()): if self._disable_amp and amp_enabled: inputs = alf.layers.to_float32(inputs) amp_enabled = False - with torch.cuda.amp.autocast(amp_enabled, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=amp_enabled, + dtype=self._amp_dtype): logits, state = self._projection_layer(inputs, state) logits = logits.reshape(inputs.shape[0], *self._output_shape) if len(self._output_shape) > 1: @@ -315,7 +317,9 @@ def forward(self, inputs, state=()): if self._disable_amp and amp_enabled: inputs = alf.layers.to_float32(inputs) amp_enabled = False - with torch.cuda.amp.autocast(amp_enabled, dtype=self._amp_dtype): + with torch.amp.autocast('cuda', + enabled=amp_enabled, + dtype=self._amp_dtype): means = self._mean_transform(self._means_projection_layer(inputs)) stds = self._std_transform(self._std_projection_layer(inputs)) return self._normal_dist(means, stds), state diff --git a/alf/utils/lean_function_test.py b/alf/utils/lean_function_test.py index 8b6521c21..dde6f6ff1 100644 --- a/alf/utils/lean_function_test.py +++ b/alf/utils/lean_function_test.py @@ -102,7 +102,7 @@ def test_lean_fucntion_autocast(self): p2.data.copy_(p1) x = torch.randn((4, 3), requires_grad=True) func2 = lean_function(func2) - with torch.cuda.amp.autocast(enabled=True): + with torch.amp.autocast('cuda', enabled=True): y1 = func1(x)[0] y2 = func2(x)[0] self.assertTensorEqual(y1, y2)