Merge pull request #2 from EdanToledo/feat/ddpg-variants
Feat/ddpg variants
EdanToledo authored Feb 19, 2024
2 parents 8e9338a + 9a22477 commit e4d57d7
Showing 13 changed files with 1,451 additions and 25 deletions.
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_ddpg.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_ddpg
- arch: anakin
- system: ff_ddpg
- network: mlp_ddpg
- env: brax/ant
- _self_
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_td3.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_td3
- arch: anakin
- system: ff_td3
- network: mlp_td3
- env: brax/ant
- _self_
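For context, the two new default files above are Hydra defaults lists: each entry selects an option from a config group (logger, arch, system, network, env), and _self_ fixes the merge order. A minimal sketch of composing one of them programmatically is shown below; the relative config_path and the example overrides are assumptions for illustration, not taken from this PR.

# Hedged sketch: compose the new TD3 defaults with Hydra's compose API.
# The config_path is relative to the calling module and is assumed here.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(
        config_name="default_ff_td3",
        overrides=["env=brax/ant", "system.total_timesteps=1e7"],  # example overrides
    )
print(OmegaConf.to_yaml(cfg))  # resolved config: logger, arch, system, network, env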
4 changes: 4 additions & 0 deletions stoix/configs/logger/ff_ddpg.yaml
@@ -0,0 +1,4 @@
defaults:
- base_logger

system_name: ff_ddpg
4 changes: 4 additions & 0 deletions stoix/configs/logger/ff_td3.yaml
@@ -0,0 +1,4 @@
defaults:
- base_logger

system_name: ff_td3
22 changes: 22 additions & 0 deletions stoix/configs/network/mlp_ddpg.yaml
@@ -0,0 +1,22 @@
# ---MLP DDPG Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256, 256]
use_layer_norm: True
activation: silu
action_head:
_target_: stoix.networks.heads.LinearOutputHead
post_processor:
_target_: stoix.networks.postprocessors.TanhToSpec

q_network:
input_layer:
_target_: stoix.networks.inputs.ObservationActionInput
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256, 256]
use_layer_norm: True
activation: silu
critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
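To make the wiring in this network config concrete, here is a rough functional equivalent in plain flax.linen: a tanh-squashed deterministic actor and a Q-network that consumes the observation concatenated with the action. This is a hedged sketch of what MLPTorso, LinearOutputHead, TanhToSpec and ObservationActionInput appear to compose, not the actual Stoix modules; class and field names below are illustrative.

# Hedged flax.linen sketch of the mlp_ddpg.yaml network wiring.
import flax.linen as nn
import jax.numpy as jnp

class DeterministicActor(nn.Module):
    action_dim: int
    layer_sizes: tuple = (256, 256, 256)

    @nn.compact
    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        x = obs
        for size in self.layer_sizes:        # pre_torso: MLP with LayerNorm and silu
            x = nn.silu(nn.LayerNorm()(nn.Dense(size)(x)))
        x = nn.Dense(self.action_dim)(x)     # action_head: linear output
        return jnp.tanh(x)                   # post_processor: squash towards the action spec

class QNetwork(nn.Module):
    layer_sizes: tuple = (256, 256, 256)

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray) -> jnp.ndarray:
        x = jnp.concatenate([obs, action], axis=-1)   # input_layer: observation + action
        for size in self.layer_sizes:
            x = nn.silu(nn.LayerNorm()(nn.Dense(size)(x)))
        return nn.Dense(1)(x).squeeze(-1)             # critic_head: scalar Q-value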
22 changes: 22 additions & 0 deletions stoix/configs/network/mlp_td3.yaml
@@ -0,0 +1,22 @@
# ---MLP TD3 Networks---
actor_network:
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256, 256]
use_layer_norm: True
activation: silu
action_head:
_target_: stoix.networks.heads.LinearOutputHead
post_processor:
_target_: stoix.networks.postprocessors.TanhToSpec

q_network:
input_layer:
_target_: stoix.networks.inputs.ObservationActionInput
pre_torso:
_target_: stoix.networks.torso.MLPTorso
layer_sizes: [256, 256]
use_layer_norm: True
activation: silu
critic_head:
_target_: stoix.networks.heads.ScalarCriticHead
23 changes: 23 additions & 0 deletions stoix/configs/system/ff_ddpg.yaml
@@ -0,0 +1,23 @@
# --- Defaults FF-DDPG ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 42

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 16 # Number of sgd steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # size of the replay buffer.
batch_size: 256 # Number of samples to train on per device.
actor_lr: 1e-4 # the learning rate of the policy network optimizer
q_lr: 1e-4 # the learning rate of the Q network optimizer
tau: 0.01 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward: 20_000 # maximum absolute reward value
exploration_sigma: 0.1 # standard deviation of the exploration noise
huber_loss_parameter: 1.0 # parameter for the huber loss
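As a reading aid for these hyperparameters, the sketch below shows the two places exploration_sigma and tau conventionally enter a DDPG-style update: Gaussian noise added to the deterministic action during rollouts, and Polyak averaging of the target networks. It is a generic illustration under textbook DDPG assumptions, not code from this PR.

# Hedged illustration of how exploration_sigma and tau are conventionally used in DDPG.
import jax
import jax.numpy as jnp
import optax

def noisy_action(action: jnp.ndarray, key: jax.Array, exploration_sigma: float = 0.1) -> jnp.ndarray:
    # Behaviour policy: deterministic action plus Gaussian exploration noise, clipped to [-1, 1].
    noise = exploration_sigma * jax.random.normal(key, action.shape)
    return jnp.clip(action + noise, -1.0, 1.0)

def polyak_update(online_params, target_params, tau: float = 0.01):
    # Target networks track the online networks with smoothing coefficient tau.
    return optax.incremental_update(online_params, target_params, step_size=tau)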
25 changes: 25 additions & 0 deletions stoix/configs/system/ff_td3.yaml
@@ -0,0 +1,25 @@
# --- Defaults FF-TD3 ---

total_timesteps: 1e8 # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: ~ # Number of updates
seed: 42

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 16 # Number of sgd steps per rollout.
warmup_steps: 256 # Number of steps to collect before training.
buffer_size: 1_000_000 # size of the replay buffer.
batch_size: 256 # Number of samples to train on per device.
actor_lr: 1e-4 # the learning rate of the policy network optimizer
q_lr: 1e-4 # the learning rate of the Q network optimizer
tau: 0.01 # smoothing coefficient for target networks
gamma: 0.99 # discount factor
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
max_abs_reward: 20_000 # maximum absolute reward value
exploration_sigma: 0.1 # standard deviation of the exploration noise
policy_noise: 0.2 # standard deviation of the policy noise
policy_frequency: 2 # frequency of the policy update in the TD3 algorithm (delayed policy update)
noise_clip: 0.5 # noise clip parameter of the Target Policy Smoothing Regularization
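The three keys added relative to the DDPG config map onto TD3's two signature mechanisms: target policy smoothing (policy_noise, noise_clip) and delayed policy updates (policy_frequency). The snippet below is a generic sketch of those mechanisms under standard TD3 assumptions, not the Stoix implementation.

# Hedged sketch of TD3's target policy smoothing and delayed policy updates.
import jax
import jax.numpy as jnp

def smoothed_target_action(target_action: jnp.ndarray, key: jax.Array,
                           policy_noise: float = 0.2, noise_clip: float = 0.5) -> jnp.ndarray:
    # Target Policy Smoothing Regularization: perturb the target action with clipped noise
    # before evaluating the target critics, then clip back to the action bounds.
    noise = jnp.clip(policy_noise * jax.random.normal(key, target_action.shape),
                     -noise_clip, noise_clip)
    return jnp.clip(target_action + noise, -1.0, 1.0)

def should_update_policy(step: int, policy_frequency: int = 2) -> bool:
    # Delayed policy update: the actor and target networks update only every
    # policy_frequency critic updates.
    return step % policy_frequency == 0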
40 changes: 20 additions & 20 deletions stoix/systems/d4pg/ff_d4pg.py → stoix/systems/ddpg/ff_d4pg.py
@@ -29,11 +29,11 @@
from stoix.evaluator import evaluator_setup
from stoix.networks.base import CompositeNetwork
from stoix.networks.base import FeedForwardActor as Actor
from stoix.systems.d4pg.types import (
from stoix.systems.ddpg.types import (
ActorAndTarget,
D4PGLearnerState,
D4PGOptStates,
D4PGParams,
DDPGLearnerState,
DDPGOptStates,
DDPGParams,
)
from stoix.systems.q_learning.types import QsAndTarget, Transition
from stoix.systems.sac.types import ContinuousQApply
@@ -55,7 +55,7 @@

def get_default_behavior_policy(config: DictConfig, actor_apply_fn: ActorApply) -> Callable:
def behavior_policy(
params: D4PGParams, observation: Observation, key: chex.PRNGKey
params: DDPGParams, observation: Observation, key: chex.PRNGKey
) -> chex.Array:
action = actor_apply_fn(params, observation)
if config.system.exploration_sigma != 0:
@@ -67,7 +67,7 @@ def behavior_policy(

def get_warmup_fn(
env: Environment,
params: D4PGParams,
params: DDPGParams,
actor_apply_fn: ActorApply,
buffer_add_fn: Callable,
config: DictConfig,
@@ -127,7 +127,7 @@ def get_learner_fn(
update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
buffer_fns: Tuple[Callable, Callable],
config: DictConfig,
) -> LearnerFn[D4PGLearnerState]:
) -> LearnerFn[DDPGLearnerState]:
"""Get the learner function."""

# Get apply and update functions for actor and critic networks.
@@ -136,10 +136,10 @@ def get_learner_fn(
buffer_add_fn, buffer_sample_fn = buffer_fns
exploratory_actor_apply = get_default_behavior_policy(config, actor_apply_fn)

def _update_step(learner_state: D4PGLearnerState, _: Any) -> Tuple[D4PGLearnerState, Tuple]:
def _update_step(learner_state: DDPGLearnerState, _: Any) -> Tuple[DDPGLearnerState, Tuple]:
def _env_step(
learner_state: D4PGLearnerState, _: Any
) -> Tuple[D4PGLearnerState, Transition]:
learner_state: DDPGLearnerState, _: Any
) -> Tuple[DDPGLearnerState, Transition]:
"""Step the environment."""
params, opt_states, buffer_state, key, env_state, last_timestep = learner_state

@@ -161,7 +161,7 @@ def _env_step(
last_timestep.observation, action, timestep.reward, done, real_next_obs, info
)

learner_state = D4PGLearnerState(
learner_state = DDPGLearnerState(
params, opt_states, buffer_state, key, env_state, timestep
)
return learner_state, transition
@@ -292,8 +292,8 @@ def _actor_loss_fn(
q_new_params = QsAndTarget(q_new_online_params, new_target_q_params)

# PACK NEW PARAMS AND OPTIMISER STATE
new_params = D4PGParams(actor_new_params, q_new_params)
new_opt_state = D4PGOptStates(actor_new_opt_state, q_new_opt_state)
new_params = DDPGParams(actor_new_params, q_new_params)
new_opt_state = DDPGOptStates(actor_new_opt_state, q_new_opt_state)

# PACK LOSS INFO
loss_info = {
@@ -311,13 +311,13 @@ def _actor_loss_fn(
)

params, opt_states, buffer_state, key = update_state
learner_state = D4PGLearnerState(
learner_state = DDPGLearnerState(
params, opt_states, buffer_state, key, env_state, last_timestep
)
metric = traj_batch.info
return learner_state, (metric, loss_info)

def learner_fn(learner_state: D4PGLearnerState) -> ExperimentOutput[D4PGLearnerState]:
def learner_fn(learner_state: DDPGLearnerState) -> ExperimentOutput[DDPGLearnerState]:
"""Learner function.
This function represents the learner, it updates the network parameters
@@ -342,7 +342,7 @@ def learner_fn(learner_state: D4PGLearnerState) -> ExperimentOutput[D4PGLearnerS

def learner_setup(
env: Environment, keys: chex.Array, config: DictConfig
) -> Tuple[LearnerFn[D4PGLearnerState], Actor, D4PGLearnerState]:
) -> Tuple[LearnerFn[DDPGLearnerState], Actor, DDPGLearnerState]:
"""Initialise learner_fn, network, optimiser, environment and states."""
# Get available TPU cores.
n_devices = len(jax.devices())
@@ -406,8 +406,8 @@ def learner_setup(

q_opt_state = q_optim.init(q_online_params)

params = D4PGParams(actor_params, q_params)
opt_states = D4PGOptStates(actor_opt_state, q_opt_state)
params = DDPGParams(actor_params, q_params)
opt_states = DDPGOptStates(actor_opt_state, q_opt_state)

vmapped_actor_network_apply_fn = actor_network.apply
vmapped_q_network_apply_fn = q_network.apply
@@ -464,7 +464,7 @@ def learner_setup(
**config.logger.checkpointing.load_args, # Other checkpoint args
)
# Restore the learner state from the checkpoint
restored_params, _ = loaded_checkpoint.restore_params(TParams=D4PGParams)
restored_params, _ = loaded_checkpoint.restore_params(TParams=DDPGParams)
# Update the params
params = restored_params

@@ -486,7 +486,7 @@ def learner_setup(
env_states, timesteps, keys, buffer_states = warmup(
env_states, timesteps, buffer_states, warmup_keys
)
init_learner_state = D4PGLearnerState(
init_learner_state = DDPGLearnerState(
params, opt_states, buffer_states, step_keys, env_states, timesteps
)

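For readers following the rename in this diff, the new containers DDPGParams, DDPGOptStates and DDPGLearnerState are always constructed positionally. A plausible shape for them, inferred from those call sites, is sketched below; the field names are assumptions, and the real definitions live in stoix/systems/ddpg/types.py, which is not shown on this page.

# Hedged reconstruction of the renamed containers, inferred from how they are
# built positionally in the diff above; not the actual stoix.systems.ddpg.types code.
from typing import Any, NamedTuple

class DDPGParams(NamedTuple):
    actor_params: Any   # online and target actor parameters (ActorAndTarget in the imports)
    q_params: Any       # online and target critic parameters (QsAndTarget)

class DDPGOptStates(NamedTuple):
    actor_opt_state: Any
    q_opt_state: Any

class DDPGLearnerState(NamedTuple):
    params: DDPGParams
    opt_states: DDPGOptStates
    buffer_state: Any
    key: Any            # PRNG key
    env_state: Any
    timestep: Any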
