feat: add more optional functionality to mpo such as retrace
EdanToledo committed Feb 24, 2024
1 parent f4d4fbe commit 12a999b
Showing 5 changed files with 313 additions and 59 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ exclude = [
".eggs",
]
max-line-length=100
-max-cognitive-complexity=11
+max-cognitive-complexity=15
import-order-style = "google"
application-import-names = "stoix"
doctests = true
21 changes: 13 additions & 8 deletions stoix/configs/system/ff_mpo.yaml
@@ -10,21 +10,22 @@ update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 32 # Number of sgd steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
-buffer_size: 100_000 # size of the replay buffer.
-batch_size: 256 # Number of samples to train on per device.
-actor_lr: 1e-4 # the learning rate of the policy network optimizer
-q_lr: 1e-4 # the learning rate of the Q network network optimizer
+buffer_size: 500_000 # size of the replay buffer.
+batch_size: 32 # Number of samples to train on per device.
+sample_sequence_length: 8 # Number of steps to consider for each element of the batch.
+period : 1 # Period of the sampled sequences.
+actor_lr: 3e-4 # the learning rate of the policy network optimizer
+q_lr: 3e-4 # the learning rate of the Q network optimizer
dual_lr: 1e-2 # the learning rate of the alpha optimizer
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward : 20_000 # maximum absolute reward value
huber_loss_parameter : 1.0 # Huber loss parameter for Q-value regression.
num_samples: 20 # Number of MPO action samples.
-epsilon: 0.1 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
-epsilon_mean : 0.0025 # KL constraint on the mean of the Gaussian policy, the one associated with the dual variable called alpha_mean.
-epsilon_stddev: 1e-6 # KL constraint on the stddev of the Gaussian policy, the one associated with the dual variable called alpha_mean.
+epsilon: 0.01 # KL constraint on the non-parametric auxiliary policy, the one associated with the dual variable called temperature.
+epsilon_mean : 1e-3 # KL constraint on the mean of the Gaussian policy, the one associated with the dual variable called alpha_mean.
+epsilon_stddev: 1e-5 # KL constraint on the stddev of the Gaussian policy, the one associated with the dual variable called alpha_stddev.
init_log_temperature: 10. # initial value for the temperature in log-space, note a softplus (rather than an exp) will be used to transform this.
init_log_alpha_mean: 10. # initial value for the alpha_mean in log-space, note a softplus (rather than an exp) will be used to transform this.
init_log_alpha_stddev: 1000. # initial value for the alpha_stddev in log-space, note a softplus (rather than an exp) will be used to transform this.
@@ -33,3 +34,7 @@ action_penalization: True # whether to use a KL constraint to penalize actions v
epsilon_penalty: 0.001 # KL constraint on the probability of violating the action constraint.
stochastic_policy_eval: True # whether to use a stochastic policy for Q function target evaluation.
policy_eval_num_samples: 128 # Number of samples to use for Q function target evaluation if stochastic.
+use_online_policy_to_bootstrap: False # whether to use the online policy to bootstrap the Q function targets.
+use_retrace : False # whether to use the retrace algorithm for off-policy correction.
+retrace_lambda : 0.95 # the retrace lambda parameter.
+n_step_for_sequence_bootstrap : 5 # the number of steps to use for the sequence bootstrap. This is only used if use_retrace is False.
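
Note on the last two knobs: they are mutually exclusive in effect — when use_retrace is False, the Q targets fall back to n-step bootstrapped returns over each sampled sequence. As a rough, simplified sketch of what such an n-step helper computes per sequence (the actual stoix utility, batch_n_step_bootstrapped_returns, is not shown in this diff; the shapes and padding below are assumptions):

import jax.numpy as jnp

def n_step_targets(r_t, d_t, v_t, n):
    # r_t: rewards [T]; d_t: discounts [T]; v_t: bootstrap values of the
    # next state [T]. Returns [T] targets, truncated at the sequence end:
    # G_t = r_t + d_t * r_{t+1} + ... + (d_t * ... * d_{t+n-1}) * v_{t+n-1}.
    T = r_t.shape[0]
    r_pad = jnp.concatenate([r_t, jnp.zeros(n - 1)])
    d_pad = jnp.concatenate([d_t, jnp.ones(n - 1)])
    v_pad = jnp.concatenate([v_t, jnp.full((n - 1,), v_t[-1])])
    targets = v_pad[n - 1 : n - 1 + T]  # bootstrap value n steps ahead
    for i in reversed(range(n)):  # fold rewards back in, n times
        targets = r_pad[i : i + T] + d_pad[i : i + T] * targets
    return targets

With sample_sequence_length: 8 and n_step_for_sequence_bootstrap: 5, each sampled sequence yields 7 usable targets (the final step only supplies a bootstrap value), which matches the [:, :-1] slicing in the learner code below.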
166 changes: 118 additions & 48 deletions stoix/systems/mpo/ff_mpo_continuous.py
@@ -9,6 +9,7 @@
import jax
import jax.numpy as jnp
import optax
+import rlax
from colorama import Fore, Style
from flashbax.buffers.trajectory_buffer import BufferState
from flax.core.frozen_dict import FrozenDict
@@ -26,16 +27,20 @@
MPOLearnerState,
MPOOptStates,
MPOParams,
+SequenceStep,
)
from stoix.systems.mpo.utils import clip_dual_params, mpo_loss
-from stoix.systems.q_learning.types import QsAndTarget, Transition
+from stoix.systems.q_learning.types import QsAndTarget
from stoix.systems.sac.types import ContinuousQApply
from stoix.types import ActorApply, ExperimentOutput, LearnerFn, LogEnvState
from stoix.utils import make_env as environments
from stoix.utils.checkpointing import Checkpointer
from stoix.utils.jax import unreplicate_batch_dim, unreplicate_n_dims
from stoix.utils.logger import LogEvent, StoixLogger
-from stoix.utils.loss import td_learning
+from stoix.utils.multistep import (
+batch_n_step_bootstrapped_returns,
+batch_retrace_continuous,
+)
from stoix.utils.total_timestep_checker import check_total_timesteps
from stoix.utils.training import make_learning_rate

@@ -52,35 +57,37 @@ def warmup(
) -> Tuple[LogEnvState, TimeStep, BufferState, chex.PRNGKey]:
def _env_step(
carry: Tuple[LogEnvState, TimeStep, chex.PRNGKey], _: Any
-) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], Transition]:
+) -> Tuple[Tuple[LogEnvState, TimeStep, chex.PRNGKey], SequenceStep]:
"""Step the environment."""

env_state, last_timestep, key = carry
# SELECT ACTION
key, policy_key = jax.random.split(key)
actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation)
action = actor_policy.sample(seed=policy_key)
+log_prob = actor_policy.log_prob(action)

# STEP ENVIRONMENT
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
done = timestep.last().reshape(-1)
info = timestep.extras["episode_metrics"]
-real_next_obs = timestep.extras["final_observation"]

-transition = Transition(
-last_timestep.observation, action, timestep.reward, done, real_next_obs, info
+sequence_step = SequenceStep(
+last_timestep.observation, action, timestep.reward, done, log_prob, info
)

-return (env_state, timestep, key), transition
+return (env_state, timestep, key), sequence_step

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
(env_states, timesteps, keys), traj_batch = jax.lax.scan(
_env_step, (env_states, timesteps, keys), None, config.system.warmup_steps
)

# Add the trajectory to the buffer.
+# Swap the batch and time axes.
+traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch)
buffer_states = buffer_add_fn(buffer_states, traj_batch)

return env_states, timesteps, keys, buffer_states
@@ -107,32 +114,33 @@ def get_learner_fn(
buffer_add_fn, buffer_sample_fn = buffer_fns

def _update_step(learner_state: MPOLearnerState, _: Any) -> Tuple[MPOLearnerState, Tuple]:
-def _env_step(learner_state: MPOLearnerState, _: Any) -> Tuple[MPOLearnerState, Transition]:
+def _env_step(
+learner_state: MPOLearnerState, _: Any
+) -> Tuple[MPOLearnerState, SequenceStep]:
"""Step the environment."""
params, opt_states, buffer_state, key, env_state, last_timestep = learner_state

# SELECT ACTION
key, policy_key = jax.random.split(key)
actor_policy = actor_apply_fn(params.actor_params.online, last_timestep.observation)

action = actor_policy.sample(seed=policy_key)
+log_prob = actor_policy.log_prob(action)

# STEP ENVIRONMENT
env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action)

# LOG EPISODE METRICS
done = timestep.last().reshape(-1)
info = timestep.extras["episode_metrics"]
-real_next_obs = timestep.extras["final_observation"]

-transition = Transition(
-last_timestep.observation, action, timestep.reward, done, real_next_obs, info
+sequence_step = SequenceStep(
+last_timestep.observation, action, timestep.reward, done, log_prob, info
)

learner_state = MPOLearnerState(
params, opt_states, buffer_state, key, env_state, timestep
)
-return learner_state, transition
+return learner_state, sequence_step

# STEP ENVIRONMENT FOR ROLLOUT LENGTH
learner_state, traj_batch = jax.lax.scan(
@@ -142,6 +150,8 @@ def _env_step(learner_state: MPOLearnerState, _: Any) -> Tuple[MPOLearnerState,
params, opt_states, buffer_state, key, env_state, last_timestep = learner_state

# Add the trajectory to the buffer.
+# Swap the batch and time axes.
+traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch)
buffer_state = buffer_add_fn(buffer_state, traj_batch)

def _update_epoch(update_state: Tuple, _: Any) -> Tuple:
@@ -152,16 +162,21 @@ def _actor_loss_fn(
dual_params: DualParams,
target_actor_params: FrozenDict,
target_q_params: FrozenDict,
-transitions: Transition,
+sequence: SequenceStep,
key: chex.PRNGKey,
) -> chex.Array:
-online_actor_policy = actor_apply_fn(online_actor_params, transitions.obs)
-target_actor_policy = actor_apply_fn(target_actor_params, transitions.obs)
+# Reshape the observations to [B*T, ...].
+reshaped_obs = jax.tree_map(
+lambda x: jnp.reshape(x, (-1,) + x.shape[2:]), sequence.obs
+)
+
+online_actor_policy = actor_apply_fn(online_actor_params, reshaped_obs)
+target_actor_policy = actor_apply_fn(target_actor_params, reshaped_obs)
target_sampled_actions = target_actor_policy.sample(
seed=key, sample_shape=config.system.num_samples
)
target_sampled_q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))(
-target_q_params, transitions.obs, target_sampled_actions
+target_q_params, reshaped_obs, target_sampled_actions
)

# Compute the policy and dual loss.
@@ -184,34 +199,87 @@ def _actor_loss_fn(
def _q_loss_fn(
online_q_params: FrozenDict,
target_q_params: FrozenDict,
+online_actor_params: FrozenDict,
target_actor_params: FrozenDict,
-transitions: Transition,
+sequences: SequenceStep,
rng_key: chex.PRNGKey,
) -> jnp.ndarray:

-q_tm1 = q_apply_fn(online_q_params, transitions.obs, transitions.action)
-if config.system.stochastic_policy_eval:
-next_actions = actor_apply_fn(target_actor_params, transitions.next_obs).sample(
-seed=rng_key, sample_shape=config.system.policy_eval_num_samples
-)
-else:
-next_actions = actor_apply_fn(target_actor_params, transitions.next_obs).mode()[
-jnp.newaxis, ...
-]
-q_t = jax.vmap(q_apply_fn, in_axes=(None, None, 0))(
-target_q_params, transitions.next_obs, next_actions
-)
-# Compute the mean over the sampled action dimension.
-q_t = jnp.mean(q_t, axis=0)
+online_actor_policy = actor_apply_fn(
+online_actor_params, sequences.obs
+) # [B, T, ...]
+target_actor_policy = actor_apply_fn(
+target_actor_params, sequences.obs
+) # [B, T, ...]
+online_q_t = q_apply_fn(online_q_params, sequences.obs, sequences.action) # [B, T]

# Cast and clip rewards.
-discount = 1.0 - transitions.done.astype(jnp.float32)
+discount = 1.0 - sequences.done.astype(jnp.float32)
d_t = (discount * config.system.gamma).astype(jnp.float32)
r_t = jnp.clip(
-transitions.reward, -config.system.max_abs_reward, config.system.max_abs_reward
+sequences.reward, -config.system.max_abs_reward, config.system.max_abs_reward
).astype(jnp.float32)

-q_loss = td_learning(q_tm1, r_t, d_t, q_t, config.system.huber_loss_parameter)
+# Policy to use for policy evaluation and bootstrapping.
+if config.system.use_online_policy_to_bootstrap:
+policy_to_evaluate = online_actor_policy
+else:
+policy_to_evaluate = target_actor_policy
+
+# Action(s) to use for policy evaluation; shape [N, B, T].
+if config.system.stochastic_policy_eval:
+a_evaluation = policy_to_evaluate.sample(
+seed=rng_key, sample_shape=config.system.policy_eval_num_samples
+) # [N, B, T, ...]
+else:
+a_evaluation = policy_to_evaluate.mode()[jnp.newaxis, ...] # [N=1, B, T, ...]
+
+# Add a stopgrad in case we use the online policy for evaluation.
+a_evaluation = jax.lax.stop_gradient(a_evaluation)
+
+# Compute the Q-values for the evaluation actions at every step; [N, B, T].
+q_values = jax.vmap(q_apply_fn, in_axes=(None, None, 0))(
+target_q_params, sequences.obs, a_evaluation
+)
+
+# When stochastic_policy_eval == True, this corresponds to expected SARSA.
+# Otherwise, the mean is a no-op.
+v_t = jnp.mean(q_values, axis=0) # [B, T]
+
+if config.system.use_retrace:
+# Compute the log-rhos for the retrace targets.
+log_rhos = target_actor_policy.log_prob(sequences.action) - sequences.log_prob
+
+# Compute target Q-values.
+target_q_t = q_apply_fn(
+target_q_params, sequences.obs, sequences.action
+) # [B, T]
+
+# Compute retrace targets.
+# These targets use the rewards and discounts as in normal TD-learning but
+# they use a mix of bootstrapped values V(s') and Q(s', a'), weighing the
+# latter based on how likely a' is under the current policy (s' and a' are
+# samples from replay).
+# See [Munos et al., 2016](https://arxiv.org/abs/1606.02647) for more.
+retrace_error = batch_retrace_continuous(
+online_q_t[:, :-1],
+target_q_t[:, 1:-1],
+v_t[:, 1:],
+r_t[:, :-1],
+d_t[:, :-1],
+log_rhos[:, 1:-1],
+config.system.retrace_lambda,
+)
+q_loss = rlax.l2_loss(retrace_error).mean()
+else:
+n_step_value_target = batch_n_step_bootstrapped_returns(
+r_t[:, :-1],
+d_t[:, :-1],
+v_t[:, 1:],
+config.system.n_step_for_sequence_bootstrap,
+)
+td_error = online_q_t[:, :-1] - n_step_value_target
+q_loss = rlax.l2_loss(td_error).mean()

loss_info = {
"q_loss": q_loss,
Expand All @@ -223,9 +291,9 @@ def _q_loss_fn(

key, sample_key, actor_key, q_key = jax.random.split(key, num=4)

-# SAMPLE TRANSITIONS
-transition_sample = buffer_sample_fn(buffer_state, sample_key)
-transitions: Transition = transition_sample.experience
+# SAMPLE SEQUENCES
+sequence_sample = buffer_sample_fn(buffer_state, sample_key)
+sequence: SequenceStep = sequence_sample.experience

# CALCULATE ACTOR AND DUAL LOSS
actor_dual_grad_fn = jax.grad(_actor_loss_fn, argnums=(0, 1), has_aux=True)
Expand All @@ -234,7 +302,7 @@ def _q_loss_fn(
params.dual_params,
params.actor_params.target,
params.q_params.target,
-transitions,
+sequence,
actor_key,
)

Expand All @@ -243,8 +311,9 @@ def _q_loss_fn(
q_grads, q_loss_info = q_grad_fn(
params.q_params.online,
params.q_params.target,
+params.actor_params.online,
params.actor_params.target,
-transitions,
+sequence,
q_key,
)

@@ -441,24 +510,25 @@ def learner_setup(
update_fns = (actor_optim.update, q_optim.update, dual_optim.update)

# Create replay buffer
-dummy_transition = Transition(
+dummy_sequence_step = SequenceStep(
obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x),
action=jnp.zeros((action_dim), dtype=float),
reward=jnp.zeros((), dtype=float),
done=jnp.zeros((), dtype=bool),
-next_obs=jax.tree_util.tree_map(lambda x: x.squeeze(0), init_x),
+log_prob=jnp.zeros((), dtype=float),
info={"episode_return": 0.0, "episode_length": 0},
)

-buffer_fn = fbx.make_item_buffer(
-max_length=config.system.buffer_size,
-min_length=config.system.batch_size,
+buffer_fn = fbx.make_trajectory_buffer(
+max_size=config.system.buffer_size,
+min_length_time_axis=config.system.sample_sequence_length,
sample_batch_size=config.system.batch_size,
-add_batches=True,
-add_sequences=True,
+sample_sequence_length=config.system.sample_sequence_length,
+period=config.system.period,
+add_batch_size=config.arch.num_envs,
)
buffer_fns = (buffer_fn.add, buffer_fn.sample)
-buffer_states = buffer_fn.init(dummy_transition)
+buffer_states = buffer_fn.init(dummy_sequence_step)

# Get batched iterated update and replicate it to pmap it over cores.
learn = get_learner_fn(env, apply_fns, update_fns, buffer_fns, config)
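
For reference, the slicing in the retrace branch above lines up as follows: online_q_t[:, :-1] is Q(s_t, a_t) for t = 0..T-2, v_t[:, 1:] holds the expected bootstrap values at s_{t+1}, and target_q_t[:, 1:-1] / log_rhos[:, 1:-1] are the target Q-values and importance ratios at the replayed (s_{t+1}, a_{t+1}). Below is a minimal, self-contained sketch of the backward recursion a helper like batch_retrace_continuous presumably implements — the utility itself is not part of this diff, so its name, argument order, and sign convention are assumptions inferred from the call site (see Munos et al., 2016, https://arxiv.org/abs/1606.02647):

import jax
import jax.numpy as jnp

def retrace_error(q_tm1, q_t, v_t, r_t, d_t, log_rhos, retrace_lambda):
    # Per-sequence retrace error; shapes (K = sequence length - 1):
    #   q_tm1:    online Q(s_t, a_t)                            [K]
    #   q_t:      target Q(s_{t+1}, a_{t+1})                    [K - 1]
    #   v_t:      bootstrap values E_pi[Q(s_{t+1}, .)]          [K]
    #   r_t, d_t: rewards and discounts                         [K]
    #   log_rhos: log pi(a_{t+1}|s_{t+1}) - behaviour log-prob  [K - 1]
    # Truncated importance weights c_s = lambda * min(1, rho_s).
    c_t = retrace_lambda * jnp.minimum(1.0, jnp.exp(log_rhos))

    # The final step bootstraps on the expected value alone.
    g_last = r_t[-1] + d_t[-1] * v_t[-1]

    def backward(g, inputs):
        r, d, c, v, q = inputs
        # G_t = r_t + d_t * (v_{t+1} + c_{t+1} * (G_{t+1} - Q(s_{t+1}, a_{t+1})))
        g = r + d * (v + c * (g - q))
        return g, g

    _, gs = jax.lax.scan(
        backward, g_last, (r_t[:-1], d_t[:-1], c_t, v_t[:-1], q_t), reverse=True
    )
    targets = jnp.concatenate([gs, g_last[jnp.newaxis]])
    return q_tm1 - jax.lax.stop_gradient(targets)

Batching over the [B, K] arrays used in the learner is then a jax.vmap over the leading axis, e.g. jax.vmap(retrace_error, in_axes=(0, 0, 0, 0, 0, 0, None)). With c = 0 the recursion reduces to one-step expected-SARSA targets; with c = 1 it behaves like an on-policy lambda-return.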
11 changes: 10 additions & 1 deletion stoix/systems/mpo/types.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Dict, Optional, Union

import chex
import optax
@@ -11,6 +11,15 @@
from stoix.types import LogEnvState


+class SequenceStep(NamedTuple):
+obs: chex.ArrayTree
+action: chex.Array
+reward: chex.Array
+done: chex.Array
+log_prob: chex.Array
+info: Dict
+
+
class ActorAndTarget(NamedTuple):
online: FrozenDict
target: FrozenDict
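
To make the new data layout concrete: jax.lax.scan stacks rollout steps time-major, giving SequenceStep leaves of shape [T, B, ...], while flashbax's trajectory buffer adds batch-major [B, T, ...] data — hence the swapaxes calls before buffer_add_fn in ff_mpo_continuous.py above. A small illustration (all sizes are made up for the example):

import jax
import jax.numpy as jnp

from stoix.systems.mpo.types import SequenceStep

T, B, obs_dim, action_dim = 8, 16, 4, 2  # illustrative sizes only
traj_batch = SequenceStep(
    obs=jnp.zeros((T, B, obs_dim)),
    action=jnp.zeros((T, B, action_dim)),
    reward=jnp.zeros((T, B)),
    done=jnp.zeros((T, B), dtype=bool),
    log_prob=jnp.zeros((T, B)),
    info={"episode_return": jnp.zeros((T, B)), "episode_length": jnp.zeros((T, B))},
)
# Swap to [B, T, ...] before handing the pytree to the buffer's add function.
traj_batch = jax.tree_map(lambda x: jnp.swapaxes(x, 0, 1), traj_batch)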