 import torch.nn as nn
 import torch.optim
 import tqdm
+from tensordict import from_module

 from torchrl.collectors import SyncDataCollector
 from torchrl.objectives import A2CLoss
 from torchrl.objectives.value.advantages import GAE
-
 from torchrl.record.loggers import generate_exp_name, get_logger
 from utils_atari import make_parallel_env, make_ppo_models
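The newly imported `from_module` comes from the tensordict package: it captures a module's parameters and buffers as a TensorDict, which is what the `copy_model` helper added further down builds on. A minimal standalone sketch of what it returns (not part of the diff):

import torch
from tensordict import from_module

module = torch.nn.Linear(3, 4)
params = from_module(module)      # TensorDict holding the module's "weight" and "bias"
print(params["weight"].shape)     # torch.Size([4, 3])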
@@ -36,7 +36,9 @@ def __init__(self, params, **kwargs):
 
 
 class A3CWorker(mp.Process):
-    def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
+    def __init__(
+        self, name, cfg, global_actor, global_critic, optimizer, use_logger=False
+    ):
         super().__init__()
         self.name = name
         self.cfg = cfg
@@ -55,8 +57,24 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
 
         self.global_actor = global_actor
         self.global_critic = global_critic
-        self.local_actor = deepcopy(global_actor)
-        self.local_critic = deepcopy(global_critic)
+        self.local_actor = self.copy_model(global_actor)
+        self.local_critic = self.copy_model(global_critic)
+
+        logger = None
+        if use_logger and cfg.logger.backend:
+            exp_name = generate_exp_name(
+                "A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}"
+            )
+            logger = get_logger(
+                cfg.logger.backend,
+                logger_name="a3c",
+                experiment_name=exp_name,
+                wandb_kwargs={
+                    "config": dict(cfg),
+                    "project": cfg.logger.project_name,
+                    "group": cfg.logger.group_name,
+                },
+            )
 
         self.logger = logger
 
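For context, the logger fields consumed in this block (`cfg.logger.backend`, `exp_name`, `project_name`, `group_name`) map onto the Hydra/OmegaConf config the script receives. A purely illustrative sketch of that shape; all values here are assumptions, not taken from the diff:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "logger": {
            "backend": "wandb",      # leave empty/null to disable logging entirely
            "exp_name": "Atari",     # illustrative value
            "project_name": "a3c",   # illustrative value
            "group_name": None,
        },
        "env": {"env_name": "PongNoFrameskip-v4"},  # illustrative value
    }
)
print(cfg.logger.backend)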
@@ -79,6 +97,21 @@ def __init__(self, name, cfg, global_actor, global_critic, optimizer, logger=None):
         self.adv_module.set_keys(done="end-of-life", terminated="end-of-life")
         self.loss_module.set_keys(done="end-of-life", terminated="end-of-life")
 
+    def copy_model(self, model):
+        td_params = from_module(model)
+        td_new_params = td_params.data.clone()
+        td_new_params = td_new_params.apply(
+            lambda p0, p1: torch.nn.Parameter(p0)
+            if isinstance(p1, torch.nn.Parameter)
+            else p0,
+            td_params,
+        )
+        with td_params.data.to("meta").to_module(model):
+            # we don't copy any param here
+            new_model = deepcopy(model)
+        td_new_params.to_module(new_model)
+        return new_model
+
     def update(self, batch, max_grad_norm=None):
         if max_grad_norm is None:
             max_grad_norm = self.cfg.optim.max_grad_norm
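What `copy_model` does, in words: snapshot the parameters with `from_module`, clone them as detached tensors, re-wrap each clone as an `nn.Parameter` where the original leaf was one, deep-copy the module skeleton while its parameters are temporarily swapped to the meta device (so `deepcopy` duplicates structure but no storage), then install the cloned parameters into the new module. A self-contained sketch of the same steps written as a free function, with a plain `nn.Linear` standing in for the actor/critic:

import torch
from copy import deepcopy
from tensordict import from_module


def clone_module(model):
    # Mirrors the copy_model steps above, for illustration only.
    td_params = from_module(model)
    td_new_params = td_params.data.clone()  # detached copies of every tensor
    td_new_params = td_new_params.apply(
        # re-wrap as nn.Parameter wherever the original leaf was a parameter
        lambda p0, p1: torch.nn.Parameter(p0) if isinstance(p1, torch.nn.Parameter) else p0,
        td_params,
    )
    with td_params.data.to("meta").to_module(model):
        new_model = deepcopy(model)  # copies structure only; params are on "meta"
    td_new_params.to_module(new_model)  # install the cloned parameters
    return new_model


actor = torch.nn.Linear(4, 2)
copy = clone_module(actor)
assert copy.weight.data_ptr() != actor.weight.data_ptr()  # independent storage
assert torch.equal(copy.weight, actor.weight)             # identical initial values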
@@ -184,7 +217,7 @@ def run(self):
 
             # Logging only on the first worker in the dashboard.
             # Alternatively, you can use a distributed logger, or aggregate metrics from all workers.
-            if self.logger and self.name == "worker_0":
+            if self.logger:
                 for key, value in metrics_to_log.items():
                     self.logger.log_scalar(key, value, collected_frames)
         collector.shutdown()
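`log_scalar` is the generic scalar-logging entry point of torchrl's logger wrappers; the third positional argument is the step (here, the number of collected frames). A small hedged sketch using the CSV backend of `get_logger`; the directory and experiment names are illustrative:

from torchrl.record.loggers import get_logger

logger = get_logger("csv", logger_name="a3c_logs", experiment_name="demo")
for step, reward in enumerate([0.1, 0.3, 0.7]):
    logger.log_scalar("train/reward", reward, step)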
@@ -202,24 +235,15 @@ def main(cfg: DictConfig): # noqa: F821
 
     num_workers = cfg.multiprocessing.num_workers
 
-    if num_workers is None:
-        num_workers = mp.cpu_count()
-    logger = None
-    if cfg.logger.backend:
-        exp_name = generate_exp_name("A3C", f"{cfg.logger.exp_name}_{cfg.env.env_name}")
-        logger = get_logger(
-            cfg.logger.backend,
-            logger_name="a3c",
-            experiment_name=exp_name,
-            wandb_kwargs={
-                "config": dict(cfg),
-                "project": cfg.logger.project_name,
-                "group": cfg.logger.group_name,
-            },
-        )
-
     workers = [
-        A3CWorker(f"worker_{i}", cfg, global_actor, global_critic, optimizer, logger)
+        A3CWorker(
+            f"worker_{i}",
+            cfg,
+            global_actor,
+            global_critic,
+            optimizer,
+            use_logger=i == 0,
+        )
         for i in range(num_workers)
     ]
     [w.start() for w in workers]
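Since `A3CWorker` subclasses `mp.Process`, each worker's training loop lives in `run()` and is launched with `start()`; the parent then typically waits with `join()`. A generic, self-contained sketch of that pattern, unrelated to the trainer itself:

import torch.multiprocessing as mp


class Worker(mp.Process):
    def __init__(self, name):
        super().__init__()
        self.name = name

    def run(self):  # executed in the child process after start()
        print(f"{self.name} running")


if __name__ == "__main__":
    workers = [Worker(f"worker_{i}") for i in range(2)]
    [w.start() for w in workers]
    [w.join() for w in workers]  # wait for all children to finish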