From c654572d858c85eeb24d9944bb12b217ca93ea90 Mon Sep 17 00:00:00 2001
From: EdanToledo
Date: Mon, 19 Feb 2024 18:50:43 +0000
Subject: [PATCH] feat: add M-DQN

---
 stoix/configs/default_ff_mdqn.yaml  |   7 +
 stoix/configs/system/ff_mdqn.yaml   |  26 ++
 stoix/systems/q_learning/ff_mdqn.py | 523 ++++++++++++++++++++++++++++
 stoix/utils/loss.py                 |  33 ++
 4 files changed, 589 insertions(+)
 create mode 100644 stoix/configs/default_ff_mdqn.yaml
 create mode 100644 stoix/configs/system/ff_mdqn.yaml
 create mode 100644 stoix/systems/q_learning/ff_mdqn.py

diff --git a/stoix/configs/default_ff_mdqn.yaml b/stoix/configs/default_ff_mdqn.yaml
new file mode 100644
index 00000000..5561ebaa
--- /dev/null
+++ b/stoix/configs/default_ff_mdqn.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - logger: ff_dqn
+  - arch: anakin
+  - system: ff_mdqn
+  - network: mlp_dqn
+  - env: gymnax/cartpole
+  - _self_
diff --git a/stoix/configs/system/ff_mdqn.yaml b/stoix/configs/system/ff_mdqn.yaml
new file mode 100644
index 00000000..30d8adc9
--- /dev/null
+++ b/stoix/configs/system/ff_mdqn.yaml
@@ -0,0 +1,26 @@
+# --- Defaults FF-MDQN ---
+
+total_timesteps: 1e8 # Set the total environment steps.
+# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
+num_updates: ~ # Number of updates
+seed: 42
+
+# --- RL hyperparameters ---
+update_batch_size: 1 # Number of vectorised gradient updates per device.
+rollout_length: 8 # Number of environment steps per vectorised environment.
+epochs: 16 # Number of SGD steps per rollout.
+warmup_steps: 128 # Number of steps to collect before training.
+buffer_size: 100_000 # size of the replay buffer.
+batch_size: 128 # Number of samples to train on per device.
+q_lr: 1e-5 # the learning rate of the Q network optimizer
+tau: 0.005 # smoothing coefficient for target networks
+gamma: 0.99 # discount factor
+max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
+decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+training_epsilon: 0.1 # epsilon for the epsilon-greedy policy during training +evaluation_epsilon: 0.00 # epsilon for the epsilon-greedy policy during evaluation +max_abs_reward : 1000.0 # maximum absolute reward value +huber_loss_parameter: 1.0 # parameter for the huber loss +entropy_temperature: 0.03 # tau parameter +munchausen_coefficient: 0.9 # alpha parameter +clip_value_min: -1e3 diff --git a/stoix/systems/q_learning/ff_mdqn.py b/stoix/systems/q_learning/ff_mdqn.py new file mode 100644 index 00000000..eb7b24f7 --- /dev/null +++ b/stoix/systems/q_learning/ff_mdqn.py @@ -0,0 +1,523 @@ +import copy +import time +from typing import Any, Callable, Dict, Tuple + +import chex +import flashbax as fbx +import flax +import hydra +import jax +import jax.numpy as jnp +import optax +from colorama import Fore, Style +from flashbax.buffers.trajectory_buffer import BufferState +from flax.core.frozen_dict import FrozenDict +from jumanji.env import Environment +from jumanji.types import TimeStep +from omegaconf import DictConfig, OmegaConf +from rich.pretty import pprint + +from stoix.evaluator import evaluator_setup +from stoix.networks.base import FeedForwardActor as Actor +from stoix.systems.q_learning.types import DQNLearnerState, QsAndTarget, Transition +from stoix.types import ActorApply, ExperimentOutput, LearnerFn, LogEnvState +from stoix.utils import make_env as environments +from stoix.utils.checkpointing import Checkpointer +from stoix.utils.jax import unreplicate_batch_dim, unreplicate_n_dims +from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.loss import munchausen_q_learning +from stoix.utils.total_timestep_checker import check_total_timesteps +from stoix.utils.training import make_learning_rate + + +def get_warmup_fn( + env: Environment, + q_params: FrozenDict, + q_apply_fn: ActorApply, + buffer_add_fn: Callable, + config: DictConfig, +) -> Callable: + def warmup( + env_states: LogEnvState, timesteps: TimeStep, buffer_states: BufferState, keys: chex.PRNGKey + ) -> Tuple[LogEnvState, TimeStep, BufferState, chex.PRNGKey]: + def _env_step( + carry: Tuple[LogEnvState, TimeStep, chex.PRNGKey], _: Any + ) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], Transition]: + """Step the environment.""" + + env_state, last_timestep, key = carry + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = q_apply_fn(q_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + transition = Transition( + last_timestep.observation, action, timestep.reward, done, timestep.observation, info + ) + + return (env_state, timestep, key), transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + (env_states, timesteps, keys), traj_batch = jax.lax.scan( + _env_step, (env_states, timesteps, keys), None, config.system.warmup_steps + ) + + # Add the trajectory to the buffer. 
+ buffer_states = buffer_add_fn(buffer_states, traj_batch) + + return env_states, timesteps, keys, buffer_states + + batched_warmup_step: Callable = jax.vmap( + warmup, in_axes=(0, 0, 0, 0), out_axes=(0, 0, 0, 0), axis_name="batch" + ) + + return batched_warmup_step + + +def get_learner_fn( + env: Environment, + q_apply_fn: ActorApply, + q_update_fn: optax.TransformUpdateFn, + buffer_fns: Tuple[Callable, Callable], + config: DictConfig, +) -> LearnerFn[DQNLearnerState]: + """Get the learner function.""" + + buffer_add_fn, buffer_sample_fn = buffer_fns + + def _update_step(learner_state: DQNLearnerState, _: Any) -> Tuple[DQNLearnerState, Tuple]: + def _env_step(learner_state: DQNLearnerState, _: Any) -> Tuple[DQNLearnerState, Transition]: + """Step the environment.""" + q_params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = q_apply_fn(q_params.online, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + transition = Transition( + last_timestep.observation, action, timestep.reward, done, timestep.observation, info + ) + + learner_state = DQNLearnerState( + q_params, opt_states, buffer_state, key, env_state, timestep + ) + return learner_state, transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + params, opt_states, buffer_state, key, env_state, last_timestep = learner_state + + # Add the trajectory to the buffer. + buffer_state = buffer_add_fn(buffer_state, traj_batch) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _q_loss_fn( + q_params: FrozenDict, + target_q_params: FrozenDict, + transitions: Transition, + ) -> jnp.ndarray: + + q_tm1 = q_apply_fn(q_params, transitions.obs).preferences + q_tm1_target = q_apply_fn(target_q_params, transitions.obs).preferences + q_t_target = q_apply_fn(target_q_params, transitions.next_obs).preferences + + # Cast and clip rewards. + discount = 1.0 - transitions.done.astype(jnp.float32) + d_t = (discount * config.system.gamma).astype(jnp.float32) + r_t = jnp.clip( + transitions.reward, -config.system.max_abs_reward, config.system.max_abs_reward + ).astype(jnp.float32) + a_tm1 = transitions.action + + batch_loss = munchausen_q_learning( + q_tm1, + q_tm1_target, + a_tm1, + r_t, + d_t, + q_t_target, + config.system.entropy_temperature, + config.system.munchausen_coefficient, + config.system.clip_value_min, + config.system.huber_loss_parameter, + ) + + loss_info = { + "q_loss": jnp.mean(batch_loss), + } + + return jnp.mean(batch_loss), loss_info + + params, opt_states, buffer_state, key = update_state + + key, sample_key = jax.random.split(key) + + # SAMPLE TRANSITIONS + transition_sample = buffer_sample_fn(buffer_state, sample_key) + transitions: Transition = transition_sample.experience + + # CALCULATE Q LOSS + q_grad_fn = jax.grad(_q_loss_fn, has_aux=True) + q_grads, q_loss_info = q_grad_fn( + params.online, + params.target, + transitions, + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. 
+ # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="batch") + q_grads, q_loss_info = jax.lax.pmean((q_grads, q_loss_info), axis_name="device") + + # UPDATE Q PARAMS AND OPTIMISER STATE + q_updates, q_new_opt_state = q_update_fn(q_grads, opt_states) + q_new_online_params = optax.apply_updates(params.online, q_updates) + # Target network polyak update. + new_target_q_params = optax.incremental_update( + q_new_online_params, params.target, config.system.tau + ) + q_new_params = QsAndTarget(q_new_online_params, new_target_q_params) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = q_new_params + new_opt_state = q_new_opt_state + + # PACK LOSS INFO + loss_info = { + "total_loss": q_loss_info["q_loss"], + } + return (new_params, new_opt_state, buffer_state, key), loss_info + + update_state = (params, opt_states, buffer_state, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.epochs + ) + + params, opt_states, buffer_state, key = update_state + learner_state = DQNLearnerState( + params, opt_states, buffer_state, key, env_state, last_timestep + ) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: DQNLearnerState) -> ExperimentOutput[DQNLearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[DQNLearnerState], Actor, DQNLearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions. + action_dim = int(env.action_spec().num_values) + config.system.action_dim = action_dim + + # PRNG keys. + key, q_net_key = keys + + # Define networks and optimiser. + q_network_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + q_network_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, + action_dim=action_dim, + epsilon=config.system.training_epsilon, + ) + + q_network = Actor(torso=q_network_torso, action_head=q_network_action_head) + + eval_q_network_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, + action_dim=action_dim, + epsilon=config.system.evaluation_epsilon, + ) + eval_q_network = Actor(torso=q_network_torso, action_head=eval_q_network_action_head) + + q_lr = make_learning_rate(config.system.q_lr, config, config.system.epochs) + q_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(q_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. 
+ init_x = env.observation_spec().generate_value() + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + + # Initialise q params and optimiser state. + q_online_params = q_network.init(q_net_key, init_x) + q_target_params = q_online_params + q_opt_state = q_optim.init(q_online_params) + + params = QsAndTarget(q_online_params, q_target_params) + opt_states = q_opt_state + + vmapped_q_network_apply_fn = q_network.apply + + # Pack apply and update functions. + apply_fns = vmapped_q_network_apply_fn + update_fns = q_optim.update + + # Create replay buffer + dummy_transition = Transition( + obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + action=jnp.zeros((), dtype=int), + reward=jnp.zeros((), dtype=float), + done=jnp.zeros((), dtype=bool), + next_obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x), + info={"episode_return": 0.0, "episode_length": 0}, + ) + + buffer_fn = fbx.make_item_buffer( + max_length=config.system.buffer_size, + min_length=config.system.batch_size, + sample_batch_size=config.system.batch_size, + add_batches=True, + add_sequences=True, + ) + buffer_fns = (buffer_fn.add, buffer_fn.sample) + buffer_states = buffer_fn.init(dummy_transition) + + # Get batched iterated update and replicate it to pmap it over cores. + learn = get_learner_fn(env, apply_fns, update_fns, buffer_fns, config) + learn = jax.pmap(learn, axis_name="device") + + warmup = get_warmup_fn(env, params, vmapped_q_network_apply_fn, buffer_fn.add, config) + warmup = jax.pmap(warmup, axis_name="device") + + # Initialise environment states and timesteps: across devices and batches. + key, *env_keys = jax.random.split( + key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1 + ) + env_states, timesteps = jax.vmap(env.reset, in_axes=(0))( + jnp.stack(env_keys), + ) + reshape_states = lambda x: x.reshape( + (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:] + ) + # (devices, update batch size, num_envs, ...) + env_states = jax.tree_map(reshape_states, env_states) + timesteps = jax.tree_map(reshape_states, timesteps) + + # Load model from checkpoint if specified. + if config.logger.checkpointing.load_model: + loaded_checkpoint = Checkpointer( + model_name=config.logger.system_name, + **config.logger.checkpointing.load_args, # Other checkpoint args + ) + # Restore the learner state from the checkpoint + restored_params, _ = loaded_checkpoint.restore_params(TParams=QsAndTarget) + # Update the params + params = restored_params + + # Define params to be replicated across devices and batches. + key, step_keys, warmup_keys = jax.random.split(key, num=3) + + replicate_learner = (params, opt_states, buffer_states, step_keys, warmup_keys) + + # Duplicate learner for update_batch_size. + broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape) + replicate_learner = jax.tree_map(broadcast, replicate_learner) + + # Duplicate learner across devices. + replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices()) + + # Initialise learner state. + params, opt_states, buffer_states, step_keys, warmup_keys = replicate_learner + # Warmup the buffer. 
+    env_states, timesteps, keys, buffer_states = warmup(
+        env_states, timesteps, buffer_states, warmup_keys
+    )
+    init_learner_state = DQNLearnerState(
+        params, opt_states, buffer_states, step_keys, env_states, timesteps
+    )
+
+    return learn, eval_q_network, init_learner_state
+
+
+def run_experiment(_config: DictConfig) -> None:
+    """Runs experiment."""
+    config = copy.deepcopy(_config)
+
+    # Calculate total timesteps.
+    n_devices = len(jax.devices())
+    config = check_total_timesteps(config)
+    assert (
+        config.system.num_updates > config.arch.num_evaluation
+    ), "Number of evaluations must be less than the total number of updates."
+
+    # Create the environments for train and eval.
+    env, eval_env = environments.make(config=config)
+
+    # PRNG keys.
+    key, key_e, q_net_key = jax.random.split(jax.random.PRNGKey(config["system"]["seed"]), num=3)
+
+    # Setup learner.
+    learn, eval_q_network, learner_state = learner_setup(env, (key, q_net_key), config)
+
+    # Setup evaluator.
+    evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup(
+        eval_env=eval_env,
+        key_e=key_e,
+        network=eval_q_network,
+        params=learner_state.params.online,
+        config=config,
+    )
+
+    # Calculate number of updates per evaluation.
+    config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation
+    steps_per_rollout = (
+        n_devices
+        * config.system.num_updates_per_eval
+        * config.system.rollout_length
+        * config.system.update_batch_size
+        * config.arch.num_envs
+    )
+
+    # Logger setup
+    logger = StoixLogger(config)
+    cfg: Dict = OmegaConf.to_container(config, resolve=True)
+    cfg["arch"]["devices"] = jax.devices()
+    pprint(cfg)
+
+    # Set up checkpointer
+    save_checkpoint = config.logger.checkpointing.save_model
+    if save_checkpoint:
+        checkpointer = Checkpointer(
+            metadata=config,  # Save all config as metadata in the checkpoint
+            model_name=config.logger.system_name,
+            **config.logger.checkpointing.save_args,  # Checkpoint args
+        )
+
+    # Run experiment for a total number of evaluations.
+    max_episode_return = jnp.float32(-1e6)
+    best_params = unreplicate_batch_dim(learner_state.params.online)
+    for eval_step in range(config.arch.num_evaluation):
+        # Train.
+        start_time = time.time()
+
+        learner_output = learn(learner_state)
+        jax.block_until_ready(learner_output)
+
+        # Log the results of the training.
+        elapsed_time = time.time() - start_time
+        t = int(steps_per_rollout * (eval_step + 1))
+        learner_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time
+
+        # Separately log timesteps, acting metrics and training metrics.
+        logger.log({"timestep": t}, t, eval_step, LogEvent.MISC)
+        logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT)
+        logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN)
+
+        # Prepare for evaluation.
+        start_time = time.time()
+        trained_params = unreplicate_batch_dim(
+            learner_output.learner_state.params.online
+        )  # Select only actor params
+        key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
+        eval_keys = jnp.stack(eval_keys)
+        eval_keys = eval_keys.reshape(n_devices, -1)
+
+        # Evaluate.
+        evaluator_output = evaluator(trained_params, eval_keys)
+        jax.block_until_ready(evaluator_output)
+
+        # Log the results of the evaluation.
+        elapsed_time = time.time() - start_time
+        episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"])
+
+        evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time
+        logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL)
+
+        if save_checkpoint:
+            checkpointer.save(
+                timestep=int(steps_per_rollout * (eval_step + 1)),
+                unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state),
+                episode_return=episode_return,
+            )
+
+        if config.arch.absolute_metric and max_episode_return <= episode_return:
+            best_params = copy.deepcopy(trained_params)
+            max_episode_return = episode_return
+
+        # Update runner state to continue training.
+        learner_state = learner_output.learner_state
+
+    # Measure absolute metric.
+    if config.arch.absolute_metric:
+        start_time = time.time()
+
+        key_e, *eval_keys = jax.random.split(key_e, n_devices + 1)
+        eval_keys = jnp.stack(eval_keys)
+        eval_keys = eval_keys.reshape(n_devices, -1)
+
+        evaluator_output = absolute_metric_evaluator(best_params, eval_keys)
+        jax.block_until_ready(evaluator_output)
+
+        elapsed_time = time.time() - start_time
+        t = int(steps_per_rollout * (eval_step + 1))
+        evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time
+        logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE)
+
+    # Stop the logger.
+    logger.stop()
+
+
+@hydra.main(config_path="../../configs", config_name="default_ff_mdqn.yaml", version_base="1.2")
+def hydra_entry_point(cfg: DictConfig) -> None:
+    """Experiment entry point."""
+    # Allow dynamic attributes.
+    OmegaConf.set_struct(cfg, False)
+
+    # Run experiment.
+    run_experiment(cfg)
+
+    print(f"{Fore.CYAN}{Style.BRIGHT}MDQN experiment completed{Style.RESET_ALL}")
+
+
+if __name__ == "__main__":
+    hydra_entry_point()
diff --git a/stoix/utils/loss.py b/stoix/utils/loss.py
index ab003932..d5d566d1 100644
--- a/stoix/utils/loss.py
+++ b/stoix/utils/loss.py
@@ -119,3 +119,36 @@ def categorical_td_learning(
     td_error = tfd.Categorical(probs=target).cross_entropy(tfd.Categorical(logits=v_logits_tm1))
 
     return jnp.mean(td_error)
+
+
+def munchausen_q_learning(
+    q_tm1: chex.Array,
+    q_tm1_target: chex.Array,
+    a_tm1: chex.Array,
+    r_t: chex.Array,
+    d_t: chex.Array,
+    q_t_target: chex.Array,
+    entropy_temperature: chex.Array,
+    munchausen_coefficient: chex.Array,
+    clip_value_min: chex.Array,
+    huber_loss_parameter: chex.Array,
+) -> chex.Array:
+    action_one_hot = jax.nn.one_hot(a_tm1, q_tm1.shape[-1])
+    q_tm1_a = jnp.sum(q_tm1 * action_one_hot, axis=-1)
+    # Munchausen term: tau * log_pi(a_tm1 | o_tm1), computed from the target network's Q-values
+    # and clipped from below so the log-policy bonus stays bounded.
+    munchausen_term = entropy_temperature * jax.nn.log_softmax(
+        q_tm1_target / entropy_temperature, axis=-1
+    )
+    munchausen_term_a = jnp.sum(action_one_hot * munchausen_term, axis=-1)
+    munchausen_term_a = jnp.clip(munchausen_term_a, a_min=clip_value_min, a_max=0.0)
+
+    # Soft Bellman backup: next-state value is tau * logsumexp(q / tau) under the target network.
+    next_v = entropy_temperature * jax.nn.logsumexp(q_t_target / entropy_temperature, axis=-1)
+    target_q = jax.lax.stop_gradient(
+        r_t + munchausen_coefficient * munchausen_term_a + d_t * next_v
+    )
+
+    batch_loss = rlax.huber_loss(target_q - q_tm1_a, huber_loss_parameter)
+    batch_loss = jnp.mean(batch_loss)
+    return batch_loss
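
Reviewer note (not part of the patch): the sketch below is a minimal way to exercise the new munchausen_q_learning loss on dummy data, wiring in the same hyperparameter values that ff_mdqn.py reads from stoix/configs/system/ff_mdqn.yaml (entropy_temperature, munchausen_coefficient, clip_value_min, huber_loss_parameter). It assumes the patch is applied and stoix (with jax and rlax) is importable; the batch size, number of actions, and dummy transition values are purely illustrative.

# Minimal sanity check for the new loss (illustrative shapes and values only).
import jax
import jax.numpy as jnp

from stoix.utils.loss import munchausen_q_learning

batch_size, num_actions = 32, 4
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

# Dummy batch: online and target Q-values at o_tm1, target Q-values at o_t.
q_tm1 = jax.random.normal(k1, (batch_size, num_actions))
q_tm1_target = jax.random.normal(k2, (batch_size, num_actions))
q_t_target = jax.random.normal(k3, (batch_size, num_actions))
a_tm1 = jax.random.randint(k4, (batch_size,), 0, num_actions)
r_t = jnp.ones((batch_size,))
d_t = jnp.full((batch_size,), 0.99)  # discount * (1 - done), as built in _q_loss_fn

# Target per M-DQN: r_t + alpha * clip(tau * log_pi(a_tm1|o_tm1), l_0, 0) + d_t * tau * logsumexp(q_t_target / tau)
loss = munchausen_q_learning(
    q_tm1,
    q_tm1_target,
    a_tm1,
    r_t,
    d_t,
    q_t_target,
    entropy_temperature=0.03,
    munchausen_coefficient=0.9,
    clip_value_min=-1e3,
    huber_loss_parameter=1.0,
)
print(loss)  # scalar: Huber loss averaged over the batch

As a quick wiring check, setting munchausen_coefficient=0.0 together with a very small entropy_temperature should make the target approach the standard DQN target r_t + d_t * max_a q_t_target, since tau * logsumexp(q / tau) tends to the max as tau goes to zero.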