From c19dcbfd4193872b94717cb6422fc0ae339cf863 Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Tue, 20 Feb 2024 08:59:43 +0000
Subject: [PATCH] feat: add qr-dqn

---
 stoix/configs/default_ff_qr_dqn.yaml  |   7 +
 stoix/configs/network/mlp_qr_dqn.yaml |   9 +
 stoix/configs/system/ff_qr_dqn.yaml   |  24 ++
 stoix/networks/heads.py               |  17 +
 stoix/systems/q_learning/ff_qr_dqn.py | 554 ++++++++++++++++++++++++++
 stoix/utils/loss.py                   |  95 +++++
 6 files changed, 706 insertions(+)
 create mode 100644 stoix/configs/default_ff_qr_dqn.yaml
 create mode 100644 stoix/configs/network/mlp_qr_dqn.yaml
 create mode 100644 stoix/configs/system/ff_qr_dqn.yaml
 create mode 100644 stoix/systems/q_learning/ff_qr_dqn.py

diff --git a/stoix/configs/default_ff_qr_dqn.yaml b/stoix/configs/default_ff_qr_dqn.yaml
new file mode 100644
index 00000000..57757e67
--- /dev/null
+++ b/stoix/configs/default_ff_qr_dqn.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - logger: ff_dqn
+  - arch: anakin
+  - system: ff_qr_dqn
+  - network: mlp_qr_dqn
+  - env: gymnax/cartpole
+  - _self_
diff --git a/stoix/configs/network/mlp_qr_dqn.yaml b/stoix/configs/network/mlp_qr_dqn.yaml
new file mode 100644
index 00000000..f4006a41
--- /dev/null
+++ b/stoix/configs/network/mlp_qr_dqn.yaml
@@ -0,0 +1,9 @@
+# ---MLP QR-DQN Networks---
+actor_network:
+  pre_torso:
+    _target_: stoix.networks.torso.MLPTorso
+    layer_sizes: [256, 256]
+    use_layer_norm: True
+    activation: silu
+  action_head:
+    _target_: stoix.networks.heads.QuantileDiscreteQNetwork
diff --git a/stoix/configs/system/ff_qr_dqn.yaml b/stoix/configs/system/ff_qr_dqn.yaml
new file mode 100644
index 00000000..cccda782
--- /dev/null
+++ b/stoix/configs/system/ff_qr_dqn.yaml
@@ -0,0 +1,24 @@
+# --- Defaults FF-QR-DQN ---
+
+total_timesteps: 1e8 # Set the total environment steps.
+# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
+num_updates: ~ # Number of updates
+seed: 42
+
+# --- RL hyperparameters ---
+update_batch_size: 1 # Number of vectorised gradient updates per device.
+rollout_length: 8 # Number of environment steps per vectorised environment.
+epochs: 16 # Number of SGD steps per rollout.
+warmup_steps: 128 # Number of steps to collect before training.
+buffer_size: 100_000 # Size of the replay buffer.
+batch_size: 128 # Number of samples to train on per device.
+q_lr: 1e-5 # The learning rate of the Q network optimizer.
+tau: 0.005 # Smoothing coefficient for the target network.
+gamma: 0.99 # Discount factor.
+max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
+decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
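+
+# --- Exploration and distributional hyperparameters ---
+# num_quantiles sets how many quantile midpoints tau_i = (i + 0.5) / num_quantiles
+# the quantile regression loss regresses towards (see the quantile computation in ff_qr_dqn.py).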
+training_epsilon: 0.1 # epsilon for the epsilon-greedy policy during training +evaluation_epsilon: 0.00 # epsilon for the epsilon-greedy policy during evaluation +max_abs_reward : 1000.0 # maximum absolute reward value +huber_loss_parameter: 1.0 # parameter for the huber loss +num_quantiles: 200 # number of quantiles diff --git a/stoix/networks/heads.py b/stoix/networks/heads.py index d5eb4a87..da0dfcfe 100644 --- a/stoix/networks/heads.py +++ b/stoix/networks/heads.py @@ -218,3 +218,20 @@ def __call__( q_value = jnp.sum(q_dist * atoms, axis=-1) atoms = jnp.broadcast_to(atoms, (*q_value.shape, self.num_atoms)) return q_value, q_logits, atoms + + +class QuantileDiscreteQNetwork(nn.Module): + action_dim: int + epsilon: float + num_quantiles: int + kernel_init: Initializer = lecun_normal() + + @nn.compact + def __call__(self, embedding: chex.Array) -> Tuple[distrax.EpsilonGreedy, chex.Array]: + q_logits = nn.Dense(self.action_dim * self.num_quantiles, kernel_init=self.kernel_init)( + embedding + ) + q_dist = jnp.reshape(q_logits, (-1, self.action_dim, self.num_quantiles)) + q_values = jnp.mean(q_dist, axis=-1) + q_values = jax.lax.stop_gradient(q_values) + return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_dist diff --git a/stoix/systems/q_learning/ff_qr_dqn.py b/stoix/systems/q_learning/ff_qr_dqn.py new file mode 100644 index 00000000..1696cfbe --- /dev/null +++ b/stoix/systems/q_learning/ff_qr_dqn.py @@ -0,0 +1,554 @@ +import copy +import time +from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple + +from stoix.systems.q_learning.types import DQNLearnerState, QsAndTarget, Transition +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.loss import quantile_q_learning +from stoix.utils.training import make_learning_rate + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass + +import chex +import distrax +import flashbax as fbx +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flashbax.buffers.trajectory_buffer import BufferState +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from jumanji.types import TimeStep +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.evaluator import evaluator_setup +from stoix.networks.base import FeedForwardActor as Actor +from stoix.types import ( + ActorApply, + ExperimentOutput, + LearnerFn, + LogEnvState, + Observation, +) +from stoix.utils import make_env as environments +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.total_timestep_checker import check_total_timesteps + + +def get_warmup_fn( + env: Environment, + q_params: FrozenDict, + q_apply_fn: ActorApply, + buffer_add_fn: Callable, + config: DictConfig, +) -> Callable: + def warmup( + env_states: LogEnvState, timesteps: TimeStep, buffer_states: BufferState, keys: chex.PRNGKey + ) -> Tuple[LogEnvState, TimeStep, BufferState, chex.PRNGKey]: + def _env_step( + carry: Tuple[LogEnvState, TimeStep, chex.PRNGKey], _: Any + ) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], Transition]: + """Step the environment.""" + + env_state, last_timestep, key = carry + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy, _ = q_apply_fn(q_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + + # STEP ENVIRONMENT + env_state, timestep = 
jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + transition = Transition( + last_timestep.observation, action, timestep.reward, done, timestep.observation, info + ) + + return (env_state, timestep, key), transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + (env_states, timesteps, keys), traj_batch = jax.lax.scan( + _env_step, (env_states, timesteps, keys), None, config.system.warmup_steps + ) + + # Add the trajectory to the buffer. + buffer_states = buffer_add_fn(buffer_states, traj_batch) + + return env_states, timesteps, keys, buffer_states + + batched_warmup_step: Callable = jax.vmap( + warmup, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0), axis_name="batch" + ) + + return batched_warmup_step + + +def get_learner_fn( + env: Environment, + q_apply_fn: ActorApply, + q_update_fn: optax.TransformUpdateFn, + buffer_fns: Tuple[Callable, Callable], + config: DictConfig, +) -> LearnerFn[DQNLearnerState]: + """Get the learner function.""" + + buffer_add_fn, buffer_sample_fn = buffer_fns + + def _update_step(learner_state: DQNLearnerState, _: Any) -> Tuple[DQNLearnerState, Tuple]: + def _env_step(learner_state: DQNLearnerState, _: Any) -> Tuple[DQNLearnerState, Transition]: + """Step the environment.""" + q_params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy, _ = q_apply_fn(q_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + transition = Transition( + last_timestep.observation, action, timestep.reward, done, timestep.observation, info + ) + + learner_state = DQNLearnerState( + q_params, opt_states, buffer_state, key, env_state, timestep + ) + return learner_state, transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # Add the trajectory to the buffer. + buffer_state = buffer_add_fn(buffer_state, traj_batch) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _q_loss_fn( + q_params: FrozenDict, + target_q_params: FrozenDict, + transitions: Transition, + ) -> jnp.ndarray: + + _, q_dist_tm1 = q_apply_fn(q_params, transitions.obs) + _, q_dist_t = q_apply_fn(target_q_params, transitions.next_obs) + + # Cast and clip rewards. + discount = 1.0 - transitions.done.astype(jnp.float32) + d_t = (discount * config.system.gamma).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -config.system.max_abs_reward, config.system.max_abs_reward + ).astype(jnp.float32) + a_tm1 = transitions.action + + # Swap distribution and action dimension, since + # quantile_q_learning expects it that way. 
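+                # The quantile head outputs values of shape [B, A, Q]
+                # (batch, actions, quantiles), while quantile_q_learning indexes
+                # actions on the last axis, i.e. it expects [B, Q, A].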
+ q_dist_tm1 = jnp.swapaxes(q_dist_tm1, 1, 2) + q_dist_t = jnp.swapaxes(q_dist_t, 1, 2) + quantiles = ( + jnp.arange(config.system.num_quantiles, dtype=jnp.float32) + 0.5 + ) / config.system.num_quantiles + quantiles = jnp.broadcast_to( + quantiles, (a_tm1.shape[0], config.system.num_quantiles) + ) + + q_loss = quantile_q_learning( + q_dist_tm1, + quantiles, + a_tm1, + r_t, + d_t, + q_dist_t, # No double Q-learning here. + q_dist_t, + config.system.huber_loss_parameter, + ) + + loss_info = { + "q_loss": q_loss, + } + + return q_loss, loss_info + + params, opt_states, buffer_state, key = update_state + + key, sample_key = jax.random.split(key) + + # SAMPLE TRANSITIONS + transition_sample = buffer_sample_fn(buffer_state, sample_key) + transitions: Transition = transition_sample.experience + + # CALCULATE Q LOSS + q_grad_fn = jax.grad(_q_loss_fn, has_aux=True) + q_grads, q_loss_info = q_grad_fn( + params.online, + params.target, + transitions, + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="batch") + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="device") + + # UPDATE Q PARAMS AND OPTIMISER STATE + q_updates, q_new_opt_state = q_update_fn(q_grads, opt_states) + q_new_online_params = optax.apply_updates(params.online, q_updates) + # Target network polyak update. + new_target_q_params = optax.incremental_update( + q_new_online_params, params.target, config.system.tau + ) + q_new_params = QsAndTarget(q_new_online_params, new_target_q_params) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = q_new_params + new_opt_state = q_new_opt_state + + # PACK LOSS INFO + loss_info = { + "total_loss": q_loss_info["q_loss"], + } + return (new_params, new_opt_state, buffer_state, key), loss_info + + update_state = (params, opt_states, buffer_state, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, buffer_state, key = update_state + learner_state = DQNLearnerState( + params, opt_states, buffer_state, key, env_state, last_timestep + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: DQNLearnerState) -> ExperimentOutput[DQNLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. 
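+        The updated learner state is returned together with the episode
+        metrics and the training loss metrics as an `ExperimentOutput`.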
+ """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +@dataclass +class EvalActorWrapper: + actor: Actor + + def apply(self, params: FrozenDict, x: Observation) -> distrax.EpsilonGreedy: + return self.actor.apply(params, x)[0] + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[DQNLearnerState], EvalActorWrapper, DQNLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions. + action_dim = int(env.action_spec().num_values) + config.system.action_dim = action_dim + + # PRNG keys. + key, q_net_key = keys + + # Define actor_network and optimiser. + q_network_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + q_network_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, + action_dim=action_dim, + epsilon=config.system.training_epsilon, + num_quantiles=config.system.num_quantiles, + ) + + q_network = Actor(torso=q_network_torso, action_head=q_network_action_head) + + eval_q_network_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, + action_dim=action_dim, + epsilon=config.system.evaluation_epsilon, + num_quantiles=config.system.num_quantiles, + ) + eval_q_network = Actor(torso=q_network_torso, action_head=eval_q_network_action_head) + eval_q_network = EvalActorWrapper(actor=eval_q_network) + + q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs) + q_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(q_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_x = env.observation_spec().generate_value() + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + + # Initialise q params and optimiser state. + q_online_params = q_network.init(q_net_key, init_x) + q_target_params = q_online_params + q_opt_state = q_optim.init(q_online_params) + + params = QsAndTarget(q_online_params, q_target_params) + opt_states = q_opt_state + + vmapped_q_network_apply_fn = q_network.apply + + # Pack apply and update functions. + apply_fns = vmapped_q_network_apply_fn + update_fns = q_optim.update + + # Create replay buffer + dummy_transition = Transition( + obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + action=jnp.zeros((), dtype=int), + reward=jnp.zeros((), dtype=float), + done=jnp.zeros((), dtype=bool), + next_obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + info={"episode_return": 0.0, "episode_length": 0}, + ) + + buffer_fn = fbx.make_item_buffer( + max_length=config.system.buffer_size, + min_length=config.system.batch_size, + sample_batch_size=config.system.batch_size, + add_batches=True, + add_sequences=True, + ) + buffer_fns = (buffer_fn.add, buffer_fn.sample) + buffer_states = buffer_fn.init(dummy_transition) + + # Get batched iterated update and replicate it to pmap it over cores. 
+ learn = get_learner_fn(env, apply_fns, update_fns, buffer_fns, config) + learn = jax.pmap(learn, axis_name="device") + + warmup = get_warmup_fn(env, params, vmapped_q_network_apply_fn, buffer_fn.add, config) + warmup = jax.pmap(warmup, axis_name="device") + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_map(reshape_states, env_states) + timesteps = jax.tree_map(reshape_states, timesteps) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(TParams=QsAndTarget) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys, warmup_keys = jax.random.split(key, num=3) + + replicate_learner = (params, opt_states, buffer_states, step_keys, warmup_keys) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) + replicate_learner = jax.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, buffer_states, step_keys, warmup_keys = replicate_learner + # Warmup the buffer. + env_states, timesteps, keys, buffer_states = warmup( + env_states, timesteps, buffer_states, warmup_keys + ) + init_learner_state = DQNLearnerState( + params, opt_states, buffer_states, step_keys, env_states, timesteps + ) + + return learn, eval_q_network, init_learner_state + + +def run_experiment(_config: DictConfig) -> None: + """Runs experiment.""" + config = copy.deepcopy(_config) + + # Calculate total timesteps. + n_devices = len(jax.devices()) + config = check_total_timesteps(config) + assert ( + config.system.num_updates > config.arch.num_evaluation + ), "Number of updates per evaluation must be less than total number of updates." + + # Create the enviroments for train and eval. + env, eval_env = environments.make(config=config) + + # PRNG keys. + key, key_e, q_net_key = jax.random.split(jax.random.PRNGKey(config["system"]["seed"]), num=3) + + # Setup learner. + learn, eval_q_network, learner_state = learner_setup(env, (key, q_net_key), config) + + # Setup evaluator. + evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup( + eval_env=eval_env, + key_e=key_e, + network=eval_q_network, + params=learner_state.params.online, + config=config, + ) + + # Calculate number of updates per evaluation. 
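+    # Each evaluation window therefore corresponds to
+    # n_devices * num_updates_per_eval * rollout_length * update_batch_size * num_envs
+    # environment steps, computed below as steps_per_rollout.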
+ config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(0.0) + best_params = unreplicate_batch_dim(learner_state.params.online) + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + learner_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + trained_params = unreplicate_batch_dim( + learner_output.learner_state.params.online + ) # Select only actor params + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. 
+    logger.stop()
+
+
+@hydra.main(config_path="../../configs", config_name="default_ff_qr_dqn.yaml", version_base="1.2")
+def hydra_entry_point(cfg: DictConfig) -> None:
+    """Experiment entry point."""
+    # Allow dynamic attributes.
+    OmegaConf.set_struct(cfg, False)
+
+    # Run experiment.
+    run_experiment(cfg)
+
+    print(f"{Fore.CYAN}{Style.BRIGHT}QR-DQN experiment completed{Style.RESET_ALL}")
+
+
+if __name__ == "__main__":
+    hydra_entry_point()
diff --git a/stoix/utils/loss.py b/stoix/utils/loss.py
index d5d566d1..35f2d33a 100644
--- a/stoix/utils/loss.py
+++ b/stoix/utils/loss.py
@@ -6,6 +6,10 @@
 tfd = tfp.distributions
 
+
+# These losses are generally taken from rlax but edited to explicitly take in a batch of data.
+# This is because the original rlax losses are not batched and are meant to be used with vmap,
+# which is much slower.
+
 def ppo_loss(
     pi_log_prob_t: chex.Array, b_pi_log_prob_t: chex.Array, gae_t: chex.Array, epsilon: float
@@ -152,3 +156,94 @@ def munchausen_q_learning(
     batch_loss = rlax.huber_loss(target_q - q_tm1_a, huber_loss_parameter)
     batch_loss = jnp.mean(batch_loss)
     return batch_loss
+
+
+def quantile_regression_loss(
+    dist_src: chex.Array,
+    tau_src: chex.Array,
+    dist_target: chex.Array,
+    huber_param: float = 0.0,
+) -> chex.Array:
+    """Compute (Huber) QR loss between two discrete quantile-valued distributions.
+
+    See "Distributional Reinforcement Learning with Quantile Regression" by
+    Dabney et al. (https://arxiv.org/abs/1710.10044).
+
+    Args:
+        dist_src: source distribution quantile values.
+        tau_src: source distribution probability thresholds.
+        dist_target: target distribution quantile values.
+        huber_param: Huber loss parameter, defaults to 0 (no Huber loss).
+
+    Returns:
+        Quantile regression loss.
+    """
+
+    batch_indices = jnp.arange(dist_src.shape[0])
+
+    # Calculate quantile error.
+    delta = dist_target[batch_indices, None, :] - dist_src[batch_indices, :, None]
+    delta_neg = (delta < 0.0).astype(jnp.float32)
+    delta_neg = jax.lax.stop_gradient(delta_neg)
+    weight = jnp.abs(tau_src[batch_indices, :, None] - delta_neg)
+
+    # Calculate Huber loss.
+    if huber_param > 0.0:
+        loss = rlax.huber_loss(delta, huber_param)
+    else:
+        loss = jnp.abs(delta)
+    loss *= weight
+
+    # Average over target-samples dimension, sum over src-samples dimension.
+    return jnp.sum(jnp.mean(loss, axis=-1), axis=-1)
+
+
+def quantile_q_learning(
+    dist_q_tm1: chex.Array,
+    tau_q_tm1: chex.Array,
+    a_tm1: chex.Array,
+    r_t: chex.Array,
+    d_t: chex.Array,
+    dist_q_t_selector: chex.Array,
+    dist_q_t: chex.Array,
+    huber_param: float = 0.0,
+) -> chex.Array:
+    """Implements Q-learning for quantile-valued Q distributions.
+
+    See "Distributional Reinforcement Learning with Quantile Regression" by
+    Dabney et al. (https://arxiv.org/abs/1710.10044).
+
+    Args:
+        dist_q_tm1: Q distribution at time t-1.
+        tau_q_tm1: Q distribution probability thresholds.
+        a_tm1: action index at time t-1.
+        r_t: reward at time t.
+        d_t: discount at time t.
+        dist_q_t_selector: Q distribution at time t for selecting greedy action in
+            target policy. This is separate from dist_q_t as in Double Q-Learning, but
+            can be computed with the target network and a separate set of samples.
+        dist_q_t: target Q distribution at time t.
+        huber_param: Huber loss parameter, defaults to 0 (no Huber loss).
+ + Returns: + Quantile regression Q learning loss. + """ + batch_indices = jnp.arange(a_tm1.shape[0]) + + # Only update the taken actions. + dist_qa_tm1 = dist_q_tm1[batch_indices, :, a_tm1] + + # Select target action according to greedy policy w.r.t. dist_q_t_selector. + q_t_selector = jnp.mean(dist_q_t_selector, axis=1) + a_t = jnp.argmax(q_t_selector, axis=-1) + dist_qa_t = dist_q_t[batch_indices, :, a_t] + + # Compute target, do not backpropagate into it. + dist_target = r_t[:, jnp.newaxis] + d_t[:, jnp.newaxis] * dist_qa_t + dist_target = jax.lax.stop_gradient(dist_target) + + return quantile_regression_loss(dist_qa_tm1, tau_q_tm1, dist_target, huber_param).mean()
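For reference, a minimal sketch (not part of the patch) of how quantile_q_learning is called, assuming this patch is installed and using dummy data that follows the shape convention in ff_qr_dqn.py (quantile values laid out as [batch, num_quantiles, num_actions], quantile midpoints tau_i = (i + 0.5) / num_quantiles):

import jax.numpy as jnp
from stoix.utils.loss import quantile_q_learning

batch_size, num_actions, num_quantiles = 4, 3, 8

# Quantile values per action, already swapped to [B, num_quantiles, num_actions].
dist_q_tm1 = jnp.zeros((batch_size, num_quantiles, num_actions))
dist_q_t = jnp.ones((batch_size, num_quantiles, num_actions))

# Quantile midpoints tau_i = (i + 0.5) / N, broadcast over the batch.
tau = (jnp.arange(num_quantiles, dtype=jnp.float32) + 0.5) / num_quantiles
tau_q_tm1 = jnp.broadcast_to(tau, (batch_size, num_quantiles))

a_tm1 = jnp.zeros((batch_size,), dtype=jnp.int32)  # actions taken at t-1
r_t = jnp.ones((batch_size,))                      # clipped rewards
d_t = jnp.full((batch_size,), 0.99)                # discount * (1 - done)

# Without double Q-learning, the selector distribution is the target distribution.
loss = quantile_q_learning(
    dist_q_tm1, tau_q_tm1, a_tm1, r_t, d_t, dist_q_t, dist_q_t, huber_param=1.0
)
print(loss)  # scalar Huber quantile regression loss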