 import torch.nn as nn
 import torch.optim
 import tqdm
+from tensordict import from_module
 
 from torchrl.collectors import SyncDataCollector
 from torchrl.objectives import A2CLoss
 from torchrl.objectives.value.advantages import GAE
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_atari import make_parallel_env, make_ppo_models
 
@@ -36,7 +36,9 @@ def __init__(self, params, **kwargs):
 
 
 class A3CWorker(mp.Process):
-    def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
+    def __init__(
+        self, name, cfg, global_actor, global_critic, optimizer, use_logger=False
+    ):
         super().__init__()
         self.name = name
         self.cfg = cfg
@@ -55,8 +57,24 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=Non
 
         self.global_actor = global_actor
         self.global_critic = global_critic
-        self.local_actor = deepcopy(global_actor)
-        self.local_critic = deepcopy(global_critic)
+        self.local_actor = self.copy_model(global_actor)
+        self.local_critic = self.copy_model(global_critic)
+
+        logger = None
+        if use_logger and cfg.logger.backend:
+            exp_name = generate_exp_name(
+                "A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+            )
+            logger = get_logger(
+                cfg.logger.backend,
+                logger_name="a3c",
+                experiment_name=exp_name,
+                wandb_kwargs={
+                    "config": dict(cfg),
+                    "project": cfg.logger.project_name,
+                    "group": cfg.logger.group_name,
+                },
+            )
 
         self.logger = logger
 
@@ -79,6 +97,21 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=Non
         self.adv_module.set_keys(done="end-of-life", terminated="end-of-life")
         self.loss_module.set_keys(done="end-of-life", terminated="end-of-life")
 
+    def copy_model(self, model):
+        td_params = from_module(model)
+        td_new_params = td_params.data.clone()
+        td_new_params = td_new_params.apply(
+            lambda p0, p1: torch.nn.Parameter(p0)
+            if isinstance(p1, torch.nn.Parameter)
+            else p0,
+            td_params,
+        )
+        with td_params.data.to("meta").to_module(model):
+            # we don't copy any param here
+            new_model = deepcopy(model)
+        td_new_params.to_module(new_model)
+        return new_model
+
     def update(self, batch, max_grad_norm=None):
         if max_grad_norm is None:
             max_grad_norm = self.cfg.optim.max_grad_norm
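(Aside: a minimal, standalone sketch of the parameter copy that copy_model above performs; the Linear module, shapes, and variable names are illustrative assumptions, not part of this PR.)

import torch
from tensordict import from_module

net = torch.nn.Linear(4, 2)
params = from_module(net)   # TensorDict view over net's parameters and buffers
copy = params.data.clone()  # detached copy backed by freshly allocated storage
# the copy shares no memory with the original weights, so a worker can mutate it freely
assert copy["weight"].data_ptr() != net.weight.data_ptr()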
@@ -184,7 +217,7 @@ def run(self):
 
             # Logging only on the first worker in the dashboard.
             # Alternatively, you can use a distributed logger, or aggregate metrics from all workers.
-            if self.logger and self.name == "worker_0":
+            if self.logger:
                 for key, value in metrics_to_log.items():
                     self.logger.log_scalar(key, value, collected_frames)
         collector.shutdown()
@@ -202,24 +235,15 @@ def main(cfg: DictConfig): # noqa: F821
 
     num_workers = cfg.multiprocessing.num_workers
 
-    if num_workers is None:
-        num_workers = mp.cpu_count()
-    logger = None
-    if cfg.logger.backend:
-        exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
-        logger = get_logger(
-            cfg.logger.backend,
-            logger_name="a3c",
-            experiment_name=exp_name,
-            wandb_kwargs={
-                "config": dict(cfg),
-                "project": cfg.logger.project_name,
-                "group": cfg.logger.group_name,
-            },
-        )
-
     workers = [
-        A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger)
+        A3CWorker(
+            f"worker_{i}",
+            cfg,
+            global_actor,
+            global_critic,
+            optimizer,
+            use_logger=i == 0,
+        )
         for i in range(num_workers)
     ]
     [w.start() for w in workers]
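(For reference, the meta-device step inside copy_model, deep-copying the module structure without its weights, can be reproduced in isolation roughly as below; the toy module and variable names are assumptions, not code from this PR.)

import torch
from copy import deepcopy
from tensordict import from_module

net = torch.nn.Linear(4, 2)
params = from_module(net)
with params.data.to("meta").to_module(net):
    # inside this block the module's parameters live on the "meta" device,
    # so deepcopy duplicates only the structure, not real weight storage
    skeleton = deepcopy(net)
# the original parameters are restored on exit; fill the skeleton with a detached copy
params.data.clone().to_module(skeleton)

In copy_model the cloned leaves are additionally re-wrapped as torch.nn.Parameter (the apply call), presumably so the local modules still register them as trainable parameters.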