257 changes: 186 additions & 71 deletions alf/algorithms/data_transformer.py

Large diffs are not rendered by default.

24 changes: 19 additions & 5 deletions alf/algorithms/ddpg_algorithm.py
@@ -40,9 +40,20 @@
DdpgActorState = namedtuple("DdpgActorState", ['actor', 'critics'])
DdpgState = namedtuple("DdpgState", ['actor', 'critics'])
DdpgInfo = namedtuple(
"DdpgInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor_loss", "critic", "discounted_return"
"DdpgInfo",
[
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor_loss",
"critic",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())
DdpgLossInfo = namedtuple('DdpgLossInfo', ('actor', 'critic'))
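A side note on the new optional fields: ALF's ``namedtuple`` (from ``alf.data_structures``) accepts a ``default_value``, so fields that are never populated stay ``()``. A tiny sketch, assuming that behavior:

```python
# Minimal sketch, assuming alf.data_structures.namedtuple supports default_value
# as used in the diff above; unspecified optional fields default to ().
from alf.data_structures import namedtuple

Info = namedtuple("Info", ["action", "her", "future_distance"], default_value=())

info = Info(action=1)
# Downstream code can treat empty tuples as "field not populated".
assert info.her == () and info.future_distance == ()
```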
@@ -237,10 +248,12 @@ def _sample(a, ou):
noisy_action, self._action_spec)
state = empty_state._replace(
actor=DdpgActorState(actor=state, critics=()))
# action_distribution is not supported for continuous actions for now.
# Returns empty action_distribution to fail early.
return AlgStep(
output=noisy_action,
state=state,
info=DdpgInfo(action=noisy_action, action_distribution=action))
info=DdpgInfo(action=noisy_action, action_distribution=()))

Collaborator:
Why do you need this change? By default we could think of a deterministic action distribution as an action tensor.

Contributor Author:
This is to fail early and clearly. ``action`` is a tensor, not a distribution. Putting ``action`` directly there could cause confusion when debugging. See the comment three lines above.

Contributor:
It is intended to return ``action_distribution`` here so that some other algorithm can use it (e.g. ``TracAlgorithm``). Do you find it causing a problem?

Contributor Author:
Yes, when used with e.g. Retrace, a distribution is needed, but action is not a distribution.
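As a concrete illustration of this point (plain PyTorch, not code from this PR): off-policy corrections such as Retrace need ``log_prob``, which a raw action tensor does not provide, whereas returning ``()`` fails immediately and visibly.

```python
# Illustrative only: a plain action tensor has no log_prob(), so treating it as
# a "deterministic distribution" breaks as soon as importance weights are needed.
import torch
import torch.distributions as td

noisy_action = torch.zeros(4, 2)  # [batch, action_dim], just a tensor
# noisy_action.log_prob(noisy_action)  # AttributeError: 'Tensor' has no 'log_prob'

# A near-deterministic Normal would provide log_prob, but returning () instead
# makes the missing-distribution case fail early rather than silently misbehave.
almost_det = td.Independent(td.Normal(noisy_action, 1e-6), 1)
print(almost_det.log_prob(noisy_action).shape)  # torch.Size([4])
```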


def rollout_step(self, time_step: TimeStep, state=None):
if self.need_full_rollout_state():
@@ -330,7 +343,8 @@ def train_step(self, inputs: TimeStep, state: DdpgState,
reward=inputs.reward,
step_type=inputs.step_type,
discount=inputs.discount,
action_distribution=policy_step.output,
action=policy_step.output,
action_distribution=(),
critic=critic_info,
actor_loss=policy_step.info,
discounted_return=rollout_info.discounted_return))
128 changes: 128 additions & 0 deletions alf/algorithms/dqn_algorithm.py
@@ -0,0 +1,128 @@
# Copyright (c) 2020 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DQN Algorithm."""

import torch

import alf
from alf.algorithms.config import TrainerConfig
from alf.algorithms.sac_algorithm import SacAlgorithm, ActionType, \
SacState as DqnState, SacCriticState as DqnCriticState, \
SacInfo as DqnInfo
from alf.data_structures import TimeStep
from alf.networks import QNetwork
from alf.optimizers import AdamTF
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils.schedulers import as_scheduler


@alf.configurable
class DqnAlgorithm(SacAlgorithm):
r"""DQN/DDQN algorithm:

::

Mnih et al "Playing Atari with Deep Reinforcement Learning", arXiv:1312.5602
Hasselt et al "Deep Reinforcement Learning with Double Q-learning", arXiv:1509.06461

The difference from DDQN is that a minimum is taken over the two critics,
similar to TD3, instead of using one critic as the target of the other.

The implementation is based on the SAC algorithm.
"""

def __init__(self,
observation_spec,
action_spec: BoundedTensorSpec,
reward_spec=TensorSpec(()),
q_network_cls=QNetwork,
q_optimizer=None,
rollout_epsilon_greedy=0.1,
num_critic_replicas=2,
env=None,
config: TrainerConfig = None,
critic_loss_ctor=None,
debug_summaries=False,
name="DqnAlgorithm"):
"""
Args:
observation_spec (nested TensorSpec): representing the observations.
action_spec (nested BoundedTensorSpec): representing the actions; can
be a mixture of discrete and continuous actions. The number of
continuous actions can be arbitrary while only one discrete
action is allowed currently. If it's a mixture, then it must be
a tuple/list ``(discrete_action_spec, continuous_action_spec)``.
reward_spec (TensorSpec): a rank-1 or rank-0 tensor spec representing
the reward(s).
q_network_cls (Callable): is used to construct a QNetwork for estimating ``Q(s,a)``
given that the action is discrete. Its output spec must be consistent with
the discrete action in ``action_spec``.
q_optimizer (torch.optim.optimizer): A custom optimizer for the q network.
Uses the enclosing algorithm's optimizer if None.
rollout_epsilon_greedy (float|Scheduler): epsilon greedy policy for rollout,
e.g. ``rollout_epsilon_greedy=0.3``. Together with the greedy target
action and ``use_entropy_reward=False`` (both set internally by this
class), it turns the underlying SAC implementation into a DQN/DDQN
algorithm.
num_critic_replicas (int): number of critics to be used. Default is 2.
env (Environment): The environment to interact with. ``env`` is a
batched environment, which means that it runs multiple simulations
simultaneously. ``env`` only needs to be provided to the root
algorithm.
config (TrainerConfig): config for training. It only needs to be
provided to the algorithm which performs ``train_iter()`` by
itself.
critic_loss_ctor (None|OneStepTDLoss|MultiStepLoss): a critic loss
constructor. If ``None``, a default ``OneStepTDLoss`` will be used.
debug_summaries (bool): True if debug summaries should be created.
name (str): The name of this algorithm.
"""
self._rollout_epsilon_greedy = as_scheduler(rollout_epsilon_greedy)
# Disable alpha learning:
alpha_optimizer = AdamTF(lr=0)

super().__init__(
observation_spec=observation_spec,
action_spec=action_spec,
reward_spec=reward_spec,
actor_network_cls=None,
critic_network_cls=None,
q_network_cls=q_network_cls,
# Do not use entropy reward:
use_entropy_reward=False,
num_critic_replicas=num_critic_replicas,
env=env,
config=config,
critic_loss_ctor=critic_loss_ctor,
# Allow custom optimizer for q_network:
critic_optimizer=q_optimizer,
alpha_optimizer=alpha_optimizer,
debug_summaries=debug_summaries,
name=name)
assert self._act_type == ActionType.Discrete

def rollout_step(self, inputs: TimeStep, state: DqnState):
return super().rollout_step(
inputs, state, eps=self._rollout_epsilon_greedy())

def _critic_train_step(self, inputs: TimeStep, state: DqnCriticState,
rollout_info: DqnInfo, action, action_distribution):
return super()._critic_train_step(
inputs,
state,
rollout_info,
action,
action_distribution,
# Pick the greedy target action:
target_action_picker=lambda t: torch.max(t, dim=1)[0])
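
A minimal usage sketch for the new algorithm, assuming ALF's usual ``alf.config`` mechanism; the epsilon value below is illustrative, not taken from this PR:

```python
# Hedged configuration sketch, assuming ALF's standard configuration API.
# DqnAlgorithm reuses the SAC machinery with a greedy target and
# epsilon-greedy rollout.
import alf
from alf.algorithms.dqn_algorithm import DqnAlgorithm

alf.config("DqnAlgorithm", rollout_epsilon_greedy=0.3, num_critic_replicas=2)
alf.config("TrainerConfig", algorithm_ctor=DqnAlgorithm)
```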
4 changes: 2 additions & 2 deletions alf/algorithms/one_step_loss.py
@@ -16,12 +16,12 @@
from typing import Union, List, Callable

import alf
from alf.algorithms.td_loss import TDLoss, TDQRLoss
from alf.algorithms.td_loss import LowerBoundedTDLoss, TDQRLoss
from alf.utils import losses


@alf.configurable
class OneStepTDLoss(TDLoss):
class OneStepTDLoss(LowerBoundedTDLoss):

Collaborator:
Conceptually OneStepTDLoss is not a child (special case) of LowerBoundedTDLoss. For me, LowerBoundedTDLoss is far more special and should be a completely new class, or a child of OneStepTDLoss.

Contributor Author:
LowerBoundedTDLoss defaults to TDLoss and can be configured to enable lower bounding, so it's more general than TDLoss. Maybe I should name it something else?
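
For readers following this thread, a conceptual sketch of the relationship being described; the class and argument names below are made up and do not reflect the actual implementation in ``alf/algorithms/td_loss.py``:

```python
# Conceptual sketch only: with lower bounding off, the behavior is identical to
# the plain TD target, which is why the lower-bounded variant can sit above
# OneStepTDLoss in the class hierarchy.
import torch


class SketchTDLoss:
    def compute_target(self, td_target, discounted_return):
        return td_target


class SketchLowerBoundedTDLoss(SketchTDLoss):
    def __init__(self, lower_bound_with_return: bool = False):
        self._lower_bound_with_return = lower_bound_with_return

    def compute_target(self, td_target, discounted_return):
        if self._lower_bound_with_return:
            # Use the observed discounted return as a lower bound on the value target.
            return torch.maximum(td_target, discounted_return)
        return td_target  # default: exactly the plain TD behavior
```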

def __init__(self,
gamma: Union[float, List[float]] = 0.99,
td_error_loss_fn: Callable = losses.element_wise_squared_loss,
6 changes: 5 additions & 1 deletion alf/algorithms/rl_algorithm.py
@@ -223,19 +223,23 @@ def __init__(self,
replay_buffer_length = adjust_replay_buffer_length(
config, self._num_earliest_frames_ignored)

total_replay_size = replay_buffer_length * self._env.batch_size
if config.whole_replay_buffer_training and config.clear_replay_buffer:
# For whole replay buffer training, we would like to be sure
# that the replay buffer has enough samples in it to perform
# the training, which will most likely happen in the 2nd
# iteration. The minimum_initial_collect_steps guarantees that.
minimum_initial_collect_steps = replay_buffer_length * self._env.batch_size
minimum_initial_collect_steps = total_replay_size
if config.initial_collect_steps < minimum_initial_collect_steps:
common.info(
'Set the initial_collect_steps to minimum required '
f'value {minimum_initial_collect_steps} because '
'whole_replay_buffer_training is on.')
config.initial_collect_steps = minimum_initial_collect_steps

assert config.initial_collect_steps <= total_replay_size, \
"Training will not happen - insufficient replay buffer size."

self.set_replay_buffer(self._env.batch_size, replay_buffer_length,
config.priority_replay)
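
A small numeric illustration (hypothetical values) of what the added assertion guards against: the buffer can hold at most ``replay_buffer_length * env.batch_size`` steps, so a larger ``initial_collect_steps`` could never be satisfied.

```python
# Hypothetical numbers, just to show what the new check prevents.
replay_buffer_length = 1000
env_batch_size = 32
total_replay_size = replay_buffer_length * env_batch_size  # 32000 steps of capacity

initial_collect_steps = 40000  # misconfigured: more than the buffer can ever hold
assert initial_collect_steps <= total_replay_size, \
    "Training will not happen - insufficient replay buffer size."
```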

38 changes: 30 additions & 8 deletions alf/algorithms/sac_algorithm.py
@@ -37,6 +37,7 @@
from alf.tensor_specs import TensorSpec, BoundedTensorSpec
from alf.utils import losses, common, dist_utils, math_ops
from alf.utils.normalizers import ScalarAdaptiveNormalizer
from alf.utils.schedulers import as_scheduler

ActionType = Enum('ActionType', ('Discrete', 'Continuous', 'Mixed'))

@@ -54,9 +55,22 @@
"SacActorInfo", ["actor_loss", "neg_entropy"], default_value=())

SacInfo = namedtuple(
"SacInfo", [
"reward", "step_type", "discount", "action", "action_distribution",
"actor", "critic", "alpha", "log_pi", "discounted_return"
"SacInfo",
[
"reward",
"step_type",
"discount",
"action",
"action_distribution",
"actor",
"critic",
"alpha",
"log_pi",
# Optional fields for value target lower bounding or Hindsight relabeling.
# TODO: Extract these into a HerAlgorithm wrapper for easier adoption of HER.
"discounted_return",
"future_distance",
"her"
],
default_value=())

@@ -541,7 +555,7 @@ def predict_step(self, inputs: TimeStep, state: SacState):
state=SacState(action=action_state),
info=SacInfo(action_distribution=action_dist))

def rollout_step(self, inputs: TimeStep, state: SacState):
def rollout_step(self, inputs: TimeStep, state: SacState, eps: float = 1.):
"""``rollout_step()`` basically predicts actions like what is done by
``predict_step()``. Additionally, if states are to be stored in a replay
buffer, then this function also calls ``_critic_networks`` and
@@ -550,7 +564,7 @@ def rollout_step(self, inputs: TimeStep, state: SacState):
action_dist, action, _, action_state = self._predict_action(
inputs.observation,
state=state.action,
epsilon_greedy=1.0,
epsilon_greedy=eps,
eps_greedy_sampling=True,
rollout=True)
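
The ``eps`` here is supplied by ``DqnAlgorithm`` through its scheduled ``rollout_epsilon_greedy``. A short sketch of the assumed ``as_scheduler`` behavior (a plain float becomes a constant callable, as implied by the ``self._rollout_epsilon_greedy()`` call in ``dqn_algorithm.py``):

```python
# Sketch of the assumed scheduler behavior; grounded only in how the new
# DqnAlgorithm code uses it (wrap a float, then call with no arguments).
from alf.utils.schedulers import as_scheduler

eps_fn = as_scheduler(0.1)
print(eps_fn())  # 0.1 -- this value is forwarded as ``eps`` to rollout_step()
```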

@@ -694,8 +708,13 @@ def _select_q_value(self, action, q_values):
*self._reward_spec.shape).long()
return q_values.gather(2, action).squeeze(2)

def _critic_train_step(self, inputs: TimeStep, state: SacCriticState,
rollout_info: SacInfo, action, action_distribution):
def _critic_train_step(self,
inputs: TimeStep,
state: SacCriticState,
rollout_info: SacInfo,
action,
action_distribution,
target_action_picker: Callable = None):
critics, critics_state = self._compute_critics(
self._critic_networks,
inputs.observation,
@@ -717,7 +736,10 @@ def _critic_train_step(self, inputs: TimeStep, state: SacCriticState,
probs = common.expand_dims_as(action_distribution.probs,
target_critics)
# [B, reward_dim]
target_critics = torch.sum(probs * target_critics, dim=1)
if target_action_picker is not None:
target_critics = target_action_picker(target_critics)
else:
target_critics = torch.sum(probs * target_critics, dim=1)
elif self._act_type == ActionType.Mixed:
critics = self._select_q_value(rollout_info.action[0], critics)
discrete_act_dist = action_distribution[0]
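
A self-contained sketch (assumed shapes, plain PyTorch) contrasting the two bootstrap targets this branch now supports: the SAC path takes an expectation under the discrete action distribution, while the DQN path takes the greedy maximum via ``target_action_picker``:

```python
# Illustrative shapes: [B=1, num_actions=3]; not ALF code, just the arithmetic.
import torch

target_critics = torch.tensor([[1.0, 3.0, 2.0]])
probs = torch.tensor([[0.2, 0.5, 0.3]])

sac_target = torch.sum(probs * target_critics, dim=1)  # tensor([2.3000]) - expectation
dqn_target = torch.max(target_critics, dim=1)[0]       # tensor([3.])     - greedy max
```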