
Commit a6eb18d: "amend"

Parent: b49e35a

1 file changed: +46 −22 lines

sota-implementations/a3c/a3c_atari.py

@@ -9,11 +9,11 @@
 import torch.nn as nn
 import torch.optim
 import tqdm
+from tensordict import from_module
 
 from torchrl.collectors import SyncDataCollector
 from torchrl.objectives import A2CLoss
 from torchrl.objectives.value.advantages import GAE
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_atari import make_parallel_env, make_ppo_models
 
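Note: the new tensordict import above powers the copy_model helper added further down. As a minimal sketch of the round trip it enables (toy module, illustrative names, not the repo's networks), from_module extracts a module's parameters and buffers into a TensorDict, and to_module writes values back:

import torch.nn as nn
from tensordict import from_module

module = nn.Linear(3, 4)
params = from_module(module)    # TensorDict with "weight" and "bias" entries
snapshot = params.data.clone()  # detached, independent copy of the values
snapshot.to_module(module)      # write the values back into the module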

@@ -36,7 +36,9 @@ def __init__(self, params, **kwargs):
 
 
 class A3CWorker(mp.Process):
-    def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
+    def __init__(
+        self, name, cfg, global_actor, global_critic, optimizer, use_logger=False
+    ):
         super().__init__()
         self.name = name
         self.cfg = cfg
@@ -55,8 +57,24 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
 
         self.global_actor = global_actor
         self.global_critic = global_critic
-        self.local_actor = deepcopy(global_actor)
-        self.local_critic = deepcopy(global_critic)
+        self.local_actor = self.copy_model(global_actor)
+        self.local_critic = self.copy_model(global_critic)
+
+        logger = None
+        if use_logger and cfg.logger.backend:
+            exp_name = generate_exp_name(
+                "A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+            )
+            logger = get_logger(
+                cfg.logger.backend,
+                logger_name="a3c",
+                experiment_name=exp_name,
+                wandb_kwargs={
+                    "config": dict(cfg),
+                    "project": cfg.logger.project_name,
+                    "group": cfg.logger.group_name,
+                },
+            )
 
         self.logger = logger
 
@@ -79,6 +97,21 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
         self.adv_module.set_keys(done="end-of-life", terminated="end-of-life")
         self.loss_module.set_keys(done="end-of-life", terminated="end-of-life")
 
+    def copy_model(self, model):
+        td_params = from_module(model)
+        td_new_params = td_params.data.clone()
+        td_new_params = td_new_params.apply(
+            lambda p0, p1: torch.nn.Parameter(p0)
+            if isinstance(p1, torch.nn.Parameter)
+            else p0,
+            td_params,
+        )
+        with td_params.data.to("meta").to_module(model):
+            # we don't copy any param here
+            new_model = deepcopy(model)
+        td_new_params.to_module(new_model)
+        return new_model
+
     def update(self, batch, max_grad_norm=None):
         if max_grad_norm is None:
             max_grad_norm = self.cfg.optim.max_grad_norm
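The meta-device swap above lets deepcopy reproduce the module structure without copying real parameter storage; the cloned values are then installed as fresh leaf Parameters, detached from the global network's autograd graph. A self-contained sanity check of the same technique, assuming tensordict's from_module/to_module semantics (toy module, not the repo's networks):

from copy import deepcopy

import torch
import torch.nn as nn
from tensordict import from_module

def copy_model(model):
    td_params = from_module(model)
    td_new_params = td_params.data.clone().apply(
        lambda p0, p1: nn.Parameter(p0) if isinstance(p1, nn.Parameter) else p0,
        td_params,
    )
    with td_params.data.to("meta").to_module(model):
        new_model = deepcopy(model)  # params are on "meta": structure only
    td_new_params.to_module(new_model)
    return new_model

net = nn.Linear(2, 2)
clone = copy_model(net)
with torch.no_grad():
    clone.weight.add_(1.0)
assert not torch.equal(clone.weight, net.weight)  # storages are independent
assert isinstance(clone.weight, nn.Parameter)     # leaves remain trainable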
@@ -184,7 +217,7 @@ def run(self):
 
             # Logging only on the first worker in the dashboard.
             # Alternatively, you can use a distributed logger, or aggregate metrics from all workers.
-            if self.logger and self.name == "worker_0":
+            if self.logger:
                 for key, value in metrics_to_log.items():
                     self.logger.log_scalar(key, value, collected_frames)
         collector.shutdown()
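Together with the constructor change above, every worker except the elected one now holds self.logger = None, so the truthiness test alone gates logging. A runnable toy illustration of the election pattern (PrintLogger is a hypothetical stand-in for the real get_logger result):

import multiprocessing as mp

class PrintLogger:
    # hypothetical, picklable stand-in for a real logger backend
    def log_scalar(self, key, value, step):
        print(f"{key}={value} @ step {step}")

class Worker(mp.Process):
    def __init__(self, rank, use_logger=False):
        super().__init__()
        self.logger = PrintLogger() if use_logger else None

    def run(self):
        if self.logger:  # truthiness suffices: all other workers hold None
            self.logger.log_scalar("frames", 100, step=0)

if __name__ == "__main__":
    workers = [Worker(i, use_logger=i == 0) for i in range(4)]
    [w.start() for w in workers]
    [w.join() for w in workers]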
@@ -202,24 +235,15 @@ def main(cfg: DictConfig):  # noqa: F821
 
     num_workers = cfg.multiprocessing.num_workers
 
-    if num_workers is None:
-        num_workers = mp.cpu_count()
-    logger = None
-    if cfg.logger.backend:
-        exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
-        logger = get_logger(
-            cfg.logger.backend,
-            logger_name="a3c",
-            experiment_name=exp_name,
-            wandb_kwargs={
-                "config": dict(cfg),
-                "project": cfg.logger.project_name,
-                "group": cfg.logger.group_name,
-            },
-        )
-
     workers = [
-        A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger)
+        A3CWorker(
+            f"worker_{i}",
+            cfg,
+            global_actor,
+            global_critic,
+            optimizer,
+            use_logger=i == 0,
+        )
         for i in range(num_workers)
     ]
     [w.start() for w in workers]
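For context, a hedged, self-contained sketch of the broader A3C mechanic this script relies on (toy model, not the repo's code): the global networks sit in shared memory, each worker trains a local copy and pushes its gradients into the shared parameters before the shared optimizer steps.

import torch
import torch.nn as nn

global_model = nn.Linear(4, 2)
global_model.share_memory()  # parameters become visible across processes

local_model = nn.Linear(4, 2)
local_model.load_state_dict(global_model.state_dict())  # sync down

loss = local_model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
for gp, lp in zip(global_model.parameters(), local_model.parameters()):
    gp._grad = lp.grad  # classic A3C idiom: push local grads to the shared model
# a step on an optimizer bound to global_model's parameters would follow here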
