Commit a33f592

[RLlib] Clean up meta learning class and example files. (#52680)
Parent: 04c4a54

14 files changed: +522 −253 lines

rllib/BUILD

Lines changed: 23 additions & 0 deletions

@@ -3357,6 +3357,29 @@ py_test(
     ],
 )

+py_test(
+    name = "examples/algorithms/maml_lr_supervised_learning",
+    size = "large",
+    srcs = ["examples/algorithms/maml_lr_supervised_learning.py"],
+    args = [
+        "--enable-new-api-stack",
+        "--as-test",
+        "--stop-iters=70000",
+        "--meta-lr=0.001",
+        "--meta-train-batch-size=5",
+        "--fine-tune-iters=10",
+        "--fine-tune-batch-size=5",
+        "--fine-tune-lr=0.01",
+        "--noise-std=0.0",
+        "--no-plot",
+    ],
+    main = "examples/algorithms/maml_lr_supervised_learning.py",
+    tags = [
+        "examples",
+        "team:rllib",
+    ],
+)
+
 # subdirectory: catalogs/
 # ....................................
 py_test(
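
The new target wraps the MAML supervised-learning example in a CI test. For local runs, the script can be invoked directly with the same flags as in the `args` list above (a sketch; the path is relative to the repository root):

    python rllib/examples/algorithms/maml_lr_supervised_learning.py \
        --enable-new-api-stack --as-test --stop-iters=70000 \
        --meta-lr=0.001 --meta-train-batch-size=5 \
        --fine-tune-iters=10 --fine-tune-batch-size=5 --fine-tune-lr=0.01 \
        --noise-std=0.0 --no-plot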

rllib/algorithms/algorithm.py

Lines changed: 26 additions & 23 deletions

@@ -640,7 +640,7 @@ def setup(self, config: AlgorithmConfig) -> None:
         else:
             self.offline_data = None

-        if not self.offline_data:
+        if self.config.is_online or not self.config.enable_env_runner_and_connector_v2:
             # Create a set of env runner actors via a EnvRunnerGroup.
             self.env_runner_group = EnvRunnerGroup(
                 env_creator=self.env_creator,

@@ -2822,28 +2822,31 @@ def get_state(
         state = {}

         # Get (local) EnvRunner state (w/o RLModule).
-        if self._check_component(COMPONENT_ENV_RUNNER, components, not_components):
-            if self.env_runner:
-                state[COMPONENT_ENV_RUNNER] = self.env_runner.get_state(
-                    components=self._get_subcomponents(COMPONENT_RL_MODULE, components),
-                    not_components=force_list(
-                        self._get_subcomponents(COMPONENT_RL_MODULE, not_components)
+        if self.config.is_online:
+            if self._check_component(COMPONENT_ENV_RUNNER, components, not_components):
+                if self.env_runner:
+                    state[COMPONENT_ENV_RUNNER] = self.env_runner.get_state(
+                        components=self._get_subcomponents(
+                            COMPONENT_RL_MODULE, components
+                        ),
+                        not_components=force_list(
+                            self._get_subcomponents(COMPONENT_RL_MODULE, not_components)
+                        )
+                        # We don't want the RLModule state from the EnvRunners (it's
+                        # `inference_only` anyway and already provided in full by the
+                        # Learners).
+                        + [COMPONENT_RL_MODULE],
+                        **kwargs,
                     )
-                    # We don't want the RLModule state from the EnvRunners (it's
-                    # `inference_only` anyway and already provided in full by the
-                    # Learners).
-                    + [COMPONENT_RL_MODULE],
-                    **kwargs,
-                )
-            else:
-                state[COMPONENT_ENV_RUNNER] = {
-                    COMPONENT_ENV_TO_MODULE_CONNECTOR: (
-                        self.env_to_module_connector.get_state()
-                    ),
-                    COMPONENT_MODULE_TO_ENV_CONNECTOR: (
-                        self.module_to_env_connector.get_state()
-                    ),
-                }
+                else:
+                    state[COMPONENT_ENV_RUNNER] = {
+                        COMPONENT_ENV_TO_MODULE_CONNECTOR: (
+                            self.env_to_module_connector.get_state()
+                        ),
+                        COMPONENT_MODULE_TO_ENV_CONNECTOR: (
+                            self.module_to_env_connector.get_state()
+                        ),
+                    }

         # Get (local) evaluation EnvRunner state (w/o RLModule).
         if self.eval_env_runner and self._check_component(

@@ -2936,7 +2939,7 @@ def get_checkpointable_components(self) -> List[Tuple[str, "Checkpointable"]]:
         components = [
             (COMPONENT_LEARNER_GROUP, self.learner_group),
         ]
-        if not self.config.is_offline and self.env_runner:
+        if self.config.is_online:
             components.append(
                 (COMPONENT_ENV_RUNNER, self.env_runner),
             )
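
Taken together, these changes key EnvRunner creation and checkpointing off the explicit `config.is_online` flag instead of inferring it from the presence of offline data. A minimal sketch of the hybrid case this enables (assuming an online-by-default config such as PPO's, with a hypothetical data path):

    from ray.rllib.algorithms.ppo import PPOConfig

    # A config can now be offline (reads recorded episodes) and online (runs
    # env rollouts) at the same time.
    config = PPOConfig().offline_data(input_="/tmp/my_recorded_episodes")
    assert config.is_offline  # An offline data source is configured ...
    assert config.is_online   # ... but EnvRunners are still created/checkpointed.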

rllib/algorithms/algorithm_config.py

Lines changed: 47 additions & 7 deletions

@@ -88,7 +88,9 @@
     from ray.rllib.algorithms.algorithm import Algorithm
     from ray.rllib.connectors.connector_v2 import ConnectorV2
     from ray.rllib.core.learner import Learner
+    from ray.rllib.core.learner.differentiable_learner import DifferentiableLearner
     from ray.rllib.core.learner.learner_group import LearnerGroup
+    from ray.rllib.core.learner.torch.torch_meta_learner import TorchMetaLearner
     from ray.rllib.core.rl_module.rl_module import RLModule
     from ray.rllib.utils.typing import EpisodeType

@@ -358,6 +360,7 @@ def __init__(self, algo_class: Optional[type] = None):
         self.update_worker_filter_stats = True
         self.use_worker_filter_stats = True
         self.sampler_perf_stats_ema_coef = None
+        self._is_online = True

         # `self.learners()`
         self.num_learners = 0

@@ -5157,9 +5160,13 @@ def _validate_offline_settings(self):
         # action and observation spaces. Note, we require here the spaces,
         # i.e. a user cannot provide an environment instead because we do
         # not want to create the environment to receive spaces.
-        if self.is_offline and (
-            not (self.evaluation_num_env_runners > 0 or self.evaluation_interval)
-            and (self.action_space is None or self.observation_space is None)
+        if (
+            self.is_offline
+            and not self.is_online
+            and (
+                not (self.evaluation_num_env_runners > 0 or self.evaluation_interval)
+                and (self.action_space is None or self.observation_space is None)
+            )
         ):
             self._value_error(
                 "If no evaluation should be run, `action_space` and "

@@ -5228,6 +5235,14 @@ def _validate_offline_settings(self):
                 "recorded episodes cannot be read in for training."
             )

+    @property
+    def is_online(self) -> bool:
+        """Defines if this config is for online RL.
+
+        Note, a config can be for on- and offline training at the same time.
+        """
+        return self._is_online
+
     @property
     def is_offline(self) -> bool:
         """Defines, if this config is for offline RL."""

@@ -6045,6 +6060,7 @@ def learners(
         self,
         *,
         differentiable_learner_configs: List[DifferentiableLearnerConfig] = NotProvided,
+        **kwargs,
     ) -> "DifferentiableAlgorithmConfig":
         """Sets the configurations for differentiable learners.

@@ -6053,6 +6069,8 @@ def learners(
         defining the `DifferentiableLearner` classes used for the nested updates in
         `Algorithm`'s learner.
         """
+        super().learners(**kwargs)
+
         if differentiable_learner_configs is not NotProvided:
             self.differentiable_learner_configs = differentiable_learner_configs

@@ -6092,18 +6110,40 @@ def validate(self):
                 "one instance is not a `DifferentiableLearnerConfig`."
             )

-    def get_default_learner_class(self):
-        """Returns the Learner class to use for this algorithm.
+    def get_default_learner_class(self) -> Union[Type["TorchMetaLearner"], str]:
+        """Returns the `MetaLearner` class to use for this algorithm.

         Override this method in the sub-class to return the `MetaLearner`.

         Returns:
             The `MetaLearner` class to use for this algorithm either as a class
             type or as a string. (e.g. "ray.rllib.core.learner.torch.torch_meta_learner.TorchMetaLearner")
         """
-        from ray.rllib.core.learner.torch.torch_meta_learner import TorchMetaLearner
+        return NotImplemented

-        return TorchMetaLearner
+    def get_differentiable_learner_classes(
+        self,
+    ) -> List[Union[Type["DifferentiableLearner"], str]]:
+        """Returns the `DifferentiableLearner` classes to use for this algorithm.
+
+        Override this method in the sub-class to return the `DifferentiableLearner`.
+
+        Returns:
+            The `DifferentiableLearner` class to use for this algorithm either as a class
+            type or as a string. (e.g.
+            "ray.rllib.core.learner.torch.torch_meta_learner.TorchDifferentiableLearner").
+        """
+        return NotImplemented
+
+    def get_differentiable_learner_configs(self) -> List[DifferentiableLearnerConfig]:
+        """Returns the `DifferentiableLearnerConfig`s for all `DifferentiableLearner`s.
+
+        Override this method in the sub-class to return the `DifferentiableLearnerConfig`s.
+
+        Returns:
+            The `DifferentiableLearnerConfig` instances to use for this algorithm.
+        """
+        return self.differentiable_learner_configs


 class TorchCompileWhatToCompile(str, Enum):
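
Since `get_default_learner_class()` and `get_differentiable_learner_classes()` now return `NotImplemented` in the base class, sub-classes are expected to override them. A sketch of what such an override might look like (the `TorchDifferentiableLearner` import path is assumed by analogy with `TorchMetaLearner`; the config class name is hypothetical):

    from ray.rllib.core.learner.torch.torch_differentiable_learner import (
        TorchDifferentiableLearner,
    )
    from ray.rllib.core.learner.torch.torch_meta_learner import TorchMetaLearner


    class MyMAMLConfig(DifferentiableAlgorithmConfig):
        def get_default_learner_class(self):
            # The MetaLearner drives the outer (meta) update.
            return TorchMetaLearner

        def get_differentiable_learner_classes(self):
            # DifferentiableLearners run the nested, functional updates.
            return [TorchDifferentiableLearner]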

rllib/algorithms/marwil/marwil.py

Lines changed: 1 addition & 1 deletion

@@ -166,7 +166,7 @@ def __init__(self, algo_class=None):
         }

         super().__init__(algo_class=algo_class or MARWIL)
-
+        self._is_online = False
         # fmt: off
         # __sphinx_doc_begin__
         # MARWIL specific settings:
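
With the base config now defaulting to `_is_online = True`, MARWIL opts out explicitly. A quick sketch of the resulting behavior:

    from ray.rllib.algorithms.marwil import MARWILConfig

    config = MARWILConfig()
    assert not config.is_online  # MARWIL reports itself as not-online by default.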

rllib/core/learner/differentiable_learner.py

Lines changed: 22 additions & 3 deletions

@@ -123,7 +123,7 @@ def build(self) -> None:

         # TODO (simon): Move the `build_learner_connector` to the
         # `DifferentiableLearnerConfig`.
-        self._learner_connector = self.config.build_learner_connector(
+        self._learner_connector = self.learner_config.build_learner_connector(
             input_observation_space=None,
             input_action_space=None,
             device=None,

@@ -383,14 +383,13 @@ def update(
         # gradient steps inside the iterator loop above (could be a complete epoch)
         # the target networks might need to be updated earlier.
         # self.after_gradient_based_update(timesteps=timesteps or {})
-
         self.metrics.deactivate_tensor_mode()

         # Reduce results across all minibatch update steps.
         if not _no_metrics_reduce:
             return params, loss_per_module, self.metrics.reduce()
         else:
-            return params, loss_per_module, None
+            return params, loss_per_module, {}

     def _create_iterator_if_necessary(
         self,

@@ -679,6 +678,26 @@ def _check_is_built(self, error: bool = True) -> bool:
             return False
         return True

+    @abc.abstractmethod
+    def _get_tensor_variable(
+        self,
+        value: Any,
+        dtype: Any = None,
+        trainable: bool = False,
+    ) -> TensorType:
+        """Returns a framework-specific tensor variable with the initial given value.
+
+        This is a framework-specific method that should be implemented by the
+        framework-specific sub-classes.
+
+        Args:
+            value: The initial value for the tensor variable.
+
+        Returns:
+            The framework-specific tensor variable of the given initial value,
+            dtype and trainable/requires_grad property.
+        """
+
     # TODO (simon): Duplicate in Learner. Move to base class "Learnable".
     def _reset(self):
         self.metrics = MetricsLogger()
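
The new abstract `_get_tensor_variable()` mirrors the corresponding hook on `Learner`. A torch sub-class might implement it roughly like this (a sketch only, not part of this commit; `trainable` maps to `requires_grad`):

    import torch

    class MyTorchDifferentiableLearner(DifferentiableLearner):
        def _get_tensor_variable(self, value, dtype=None, trainable=False):
            # Build a tensor from the initial value; default to float32.
            return torch.tensor(
                value,
                dtype=dtype or torch.float32,
                requires_grad=trainable,
            )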

rllib/core/learner/differentiable_learner_config.py

Lines changed: 102 additions & 2 deletions

@@ -1,7 +1,12 @@
-from dataclasses import dataclass
-from typing import Callable
+from dataclasses import dataclass, fields

+from typing import Callable, List, Optional, Union
+
+from ray.rllib.connectors.connector_v2 import ConnectorV2
 from ray.rllib.core.learner.differentiable_learner import DifferentiableLearner
+from ray.rllib.core.rl_module.rl_module import RLModule
+from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
+from ray.rllib.utils.typing import ModuleID


 @dataclass

@@ -13,6 +18,16 @@ class DifferentiableLearnerConfig:
     # The `DifferentiableLearner` class. Must be derived from `DifferentiableLearner`.
     learner_class: Callable

+    learner_connector: Optional[
+        Callable[["RLModule"], Union["ConnectorV2", List["ConnectorV2"]]]
+    ] = None
+
+    add_default_connectors_to_learner_pipeline: bool = True
+
+    is_multi_agent: bool = False
+
+    policies_to_update: List[ModuleID] = None
+
     # The learning rate to use for the nested update. Note, in the default case this
     # learning rate is only used to update parameters in a functional form, i.e. the
     # `RLModule`'s stateful parameters are only updated in the `MetaLearner`. Different

@@ -45,3 +60,88 @@ def __post_init__(self):
                 "`learner_class` must be a subclass of `DifferentiableLearner "
                 f"but is {self.learner_class}."
             )
+
+    def build_learner_connector(
+        self,
+        input_observation_space,
+        input_action_space,
+        device=None,
+    ):
+        from ray.rllib.connectors.learner import (
+            AddColumnsFromEpisodesToTrainBatch,
+            AddObservationsFromEpisodesToBatch,
+            AddStatesFromEpisodesToBatch,
+            AddTimeDimToBatchAndZeroPad,
+            AgentToModuleMapping,
+            BatchIndividualItems,
+            LearnerConnectorPipeline,
+            NumpyToTensor,
+        )
+
+        custom_connectors = []
+        # Create a learner connector pipeline (including RLlib's default
+        # learner connector piece) and return it.
+        if self.learner_connector is not None:
+            val_ = self.learner_connector(
+                input_observation_space,
+                input_action_space,
+                # device,  # TODO (sven): Also pass device into custom builder.
+            )
+
+            from ray.rllib.connectors.connector_v2 import ConnectorV2
+
+            # ConnectorV2 (piece or pipeline).
+            if isinstance(val_, ConnectorV2):
+                custom_connectors = [val_]
+            # Sequence of individual ConnectorV2 pieces.
+            elif isinstance(val_, (list, tuple)):
+                custom_connectors = list(val_)
+            # Unsupported return value.
+            else:
+                raise ValueError(
+                    "`AlgorithmConfig.training(learner_connector=..)` must return "
+                    "a ConnectorV2 object or a list thereof (to be added to a "
+                    f"pipeline)! Your function returned {val_}."
+                )
+
+        pipeline = LearnerConnectorPipeline(
+            connectors=custom_connectors,
+            input_observation_space=input_observation_space,
+            input_action_space=input_action_space,
+        )
+        if self.add_default_connectors_to_learner_pipeline:
+            # Append OBS handling.
+            pipeline.append(
+                AddObservationsFromEpisodesToBatch(as_learner_connector=True)
+            )
+            # Append all other columns handling.
+            pipeline.append(AddColumnsFromEpisodesToTrainBatch())
+            # Append time-rank handler.
+            pipeline.append(AddTimeDimToBatchAndZeroPad(as_learner_connector=True))
+            # Append STATE_IN/STATE_OUT handler.
+            pipeline.append(AddStatesFromEpisodesToBatch(as_learner_connector=True))
+            # If multi-agent -> Map from AgentID-based data to ModuleID based data.
+            if self.is_multi_agent:
+                pipeline.append(
+                    AgentToModuleMapping(
+                        rl_module_specs=(
+                            self.rl_module_spec.rl_module_specs
+                            if isinstance(self.rl_module_spec, MultiRLModuleSpec)
+                            else set(self.policies)
+                        ),
+                        agent_to_module_mapping_fn=self.policy_mapping_fn,
+                    )
+                )
+            # Batch all data.
+            pipeline.append(BatchIndividualItems(multi_agent=self.is_multi_agent))
+            # Convert to Tensors.
+            pipeline.append(NumpyToTensor(as_learner_connector=True, device=device))
+        return pipeline
+
+    def update_from_kwargs(self, **kwargs):
+        """Sets all slots with values defined in `kwargs`."""
+        # Get all field names (i.e., slot names).
+        field_names = {f.name for f in fields(self)}
+        for key, value in kwargs.items():
+            if key in field_names:
+                setattr(self, key, value)
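
A usage sketch for the two new methods (the `TorchDifferentiableLearner` import path is assumed, not part of this diff):

    from ray.rllib.core.learner.differentiable_learner_config import (
        DifferentiableLearnerConfig,
    )
    from ray.rllib.core.learner.torch.torch_differentiable_learner import (
        TorchDifferentiableLearner,
    )

    config = DifferentiableLearnerConfig(learner_class=TorchDifferentiableLearner)
    # Set any dataclass field by name; unknown keys are silently ignored.
    config.update_from_kwargs(policies_to_update=["default_policy"])
    # Build the default learner connector pipeline (single-agent, no custom pieces).
    pipeline = config.build_learner_connector(
        input_observation_space=None,
        input_action_space=None,
        device=None,
    )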
