1
- from dataclasses import dataclass
2
- from typing import Callable
1
+ from dataclasses import dataclass , fields
3
2
3
+ from typing import Callable , List , Optional , Union
4
+
5
+ from ray .rllib .connectors .connector_v2 import ConnectorV2
4
6
from ray .rllib .core .learner .differentiable_learner import DifferentiableLearner
7
+ from ray .rllib .core .rl_module .rl_module import RLModule
8
+ from ray .rllib .core .rl_module .multi_rl_module import MultiRLModuleSpec
9
+ from ray .rllib .utils .typing import ModuleID
5
10
6
11
7
12
@dataclass
@@ -13,6 +18,16 @@ class DifferentiableLearnerConfig:
13
18
# The `DifferentiableLearner` class. Must be derived from `DifferentiableLearner`.
14
19
learner_class : Callable
15
20
21
+ learner_connector : Optional [
22
+ Callable [["RLModule" ], Union ["ConnectorV2" , List ["ConnectorV2" ]]]
23
+ ] = None
24
+
25
+ add_default_connectors_to_learner_pipeline : bool = True
26
+
27
+ is_multi_agent : bool = False
28
+
29
+ policies_to_update : List [ModuleID ] = None
30
+
16
31
# The learning rate to use for the nested update. Note, in the default case this
17
32
# learning rate is only used to update parameters in a functional form, i.e. the
18
33
# `RLModule`'s stateful parameters are only updated in the `MetaLearner`. Different
@@ -45,3 +60,88 @@ def __post_init__(self):
45
60
"`learner_class` must be a subclass of `DifferentiableLearner "
46
61
f"but is { self .learner_class } ."
47
62
)
63
+
64
+ def build_learner_connector (
65
+ self ,
66
+ input_observation_space ,
67
+ input_action_space ,
68
+ device = None ,
69
+ ):
70
+ from ray .rllib .connectors .learner import (
71
+ AddColumnsFromEpisodesToTrainBatch ,
72
+ AddObservationsFromEpisodesToBatch ,
73
+ AddStatesFromEpisodesToBatch ,
74
+ AddTimeDimToBatchAndZeroPad ,
75
+ AgentToModuleMapping ,
76
+ BatchIndividualItems ,
77
+ LearnerConnectorPipeline ,
78
+ NumpyToTensor ,
79
+ )
80
+
81
+ custom_connectors = []
82
+ # Create a learner connector pipeline (including RLlib's default
83
+ # learner connector piece) and return it.
84
+ if self .learner_connector is not None :
85
+ val_ = self .learner_connector (
86
+ input_observation_space ,
87
+ input_action_space ,
88
+ # device, # TODO (sven): Also pass device into custom builder.
89
+ )
90
+
91
+ from ray .rllib .connectors .connector_v2 import ConnectorV2
92
+
93
+ # ConnectorV2 (piece or pipeline).
94
+ if isinstance (val_ , ConnectorV2 ):
95
+ custom_connectors = [val_ ]
96
+ # Sequence of individual ConnectorV2 pieces.
97
+ elif isinstance (val_ , (list , tuple )):
98
+ custom_connectors = list (val_ )
99
+ # Unsupported return value.
100
+ else :
101
+ raise ValueError (
102
+ "`AlgorithmConfig.training(learner_connector=..)` must return "
103
+ "a ConnectorV2 object or a list thereof (to be added to a "
104
+ f"pipeline)! Your function returned { val_ } ."
105
+ )
106
+
107
+ pipeline = LearnerConnectorPipeline (
108
+ connectors = custom_connectors ,
109
+ input_observation_space = input_observation_space ,
110
+ input_action_space = input_action_space ,
111
+ )
112
+ if self .add_default_connectors_to_learner_pipeline :
113
+ # Append OBS handling.
114
+ pipeline .append (
115
+ AddObservationsFromEpisodesToBatch (as_learner_connector = True )
116
+ )
117
+ # Append all other columns handling.
118
+ pipeline .append (AddColumnsFromEpisodesToTrainBatch ())
119
+ # Append time-rank handler.
120
+ pipeline .append (AddTimeDimToBatchAndZeroPad (as_learner_connector = True ))
121
+ # Append STATE_IN/STATE_OUT handler.
122
+ pipeline .append (AddStatesFromEpisodesToBatch (as_learner_connector = True ))
123
+ # If multi-agent -> Map from AgentID-based data to ModuleID based data.
124
+ if self .is_multi_agent :
125
+ pipeline .append (
126
+ AgentToModuleMapping (
127
+ rl_module_specs = (
128
+ self .rl_module_spec .rl_module_specs
129
+ if isinstance (self .rl_module_spec , MultiRLModuleSpec )
130
+ else set (self .policies )
131
+ ),
132
+ agent_to_module_mapping_fn = self .policy_mapping_fn ,
133
+ )
134
+ )
135
+ # Batch all data.
136
+ pipeline .append (BatchIndividualItems (multi_agent = self .is_multi_agent ))
137
+ # Convert to Tensors.
138
+ pipeline .append (NumpyToTensor (as_learner_connector = True , device = device ))
139
+ return pipeline
140
+
141
+ def update_from_kwargs (self , ** kwargs ):
142
+ """Sets all slots with values defined in `kwargs`."""
143
+ # Get all field names (i.e., slot names).
144
+ field_names = {f .name for f in fields (self )}
145
+ for key , value in kwargs .items ():
146
+ if key in field_names :
147
+ setattr (self , key , value )
0 commit comments