diff --git a/intermediate_source/dqn_with_rnn_tutorial.py b/intermediate_source/dqn_with_rnn_tutorial.py
index bcc484f0a0..462415dcc7 100644
--- a/intermediate_source/dqn_with_rnn_tutorial.py
+++ b/intermediate_source/dqn_with_rnn_tutorial.py
@@ -342,7 +342,9 @@
 # will return a new instance of the LSTM (with shared weights) that will
 # assume that the input data is sequential in nature.
 #
-policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
+from torchrl.modules import set_recurrent_mode
+
+policy = Seq(feature, lstm, mlp, qval)
 
 ######################################################################
 # Because we still have a couple of uninitialized parameters we should
@@ -389,7 +391,10 @@
 # For the sake of efficiency, we're only running a few thousands iterations
 # here. In a real setting, the total number of frames should be set to 1M.
 #
-collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)
+
+collector = SyncDataCollector(
+    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
+)
 rb = TensorDictReplayBuffer(
     storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
 )
@@ -422,7 +427,8 @@
     rb.extend(data.unsqueeze(0).to_tensordict().cpu())
     for _ in range(utd):
         s = rb.sample().to(device, non_blocking=True)
-        loss_vals = loss_fn(s)
+        with set_recurrent_mode(True):
+            loss_vals = loss_fn(s)
         loss_vals["loss"].backward()
         optim.step()
         optim.zero_grad()
@@ -464,5 +470,5 @@
 #
 # Further Reading
 # ---------------
-# 
+#
 # - The TorchRL documentation can be found `here `_.
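
The substance of the diff is an API migration: instead of baking a sequential copy of the LSTM into the policy with lstm.set_recurrent_mode(True), the loss is now computed under the set_recurrent_mode(True) context manager imported from torchrl.modules, so a single policy instance serves both step-by-step collection and sequence-based training. The following sketch illustrates that pattern in isolation; the DummyLoss module, the random batch shapes, and the optimizer settings are placeholders invented for illustration (they are not the tutorial's actual components), and the set_recurrent_mode import assumes a TorchRL release recent enough to ship the context manager.

import torch
from tensordict import TensorDict
from torchrl.modules import set_recurrent_mode  # assumes a recent TorchRL release


# Hypothetical stand-in for the tutorial's DQN loss over the recurrent policy.
class DummyLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 1)

    def forward(self, td):
        # In the tutorial, the loss module runs the LSTM-based value network;
        # under set_recurrent_mode(True) recurrent modules unroll over time.
        loss = self.net(td["observation"]).pow(2).mean()
        return TensorDict({"loss": loss}, [])


loss_fn = DummyLoss()
optim = torch.optim.Adam(loss_fn.parameters(), lr=3e-4)

# A fake replay-buffer sample with a trailing time dimension, standing in for
# what the tutorial samples from its TensorDictReplayBuffer ([batch, time, ...]).
s = TensorDict({"observation": torch.randn(4, 50, 4)}, batch_size=[4, 50])

# Collection-time modules keep the default step-by-step behaviour; only the
# loss computation is wrapped so recurrent modules see whole sequences.
with set_recurrent_mode(True):
    loss_vals = loss_fn(s)
loss_vals["loss"].backward()
optim.step()
optim.zero_grad()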