Commit
feat: add qr-dqn
EdanToledo committed Feb 20, 2024
1 parent 0bafd6e commit c19dcbf
Showing 6 changed files with 706 additions and 0 deletions.
7 changes: 7 additions & 0 deletions stoix/configs/default_ff_qr_dqn.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_dqn
- arch: anakin
- system: ff_qr_dqn
- network: mlp_qr_dqn
- env: gymnax/cartpole
- _self_
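
For orientation, Hydra composes the groups above into a single run config. Below is a minimal, illustrative sketch of loading it with Hydra's compose API; the entrypoint added elsewhere in this commit is not shown here, the config path is assumed to resolve relative to where the snippet runs, and the attribute layout assumes Hydra's default group packaging (Stoix may package groups differently).

from hydra import compose, initialize

# Assumes the relative path to stoix/configs resolves from the calling location.
with initialize(version_base=None, config_path="stoix/configs"):
    cfg = compose(config_name="default_ff_qr_dqn")

# Under Hydra's default group packaging, the system fields land under cfg.system,
# e.g. cfg.system.num_quantiles == 200 from ff_qr_dqn.yaml.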
9 changes: 9 additions & 0 deletions stoix/configs/network/mlp_qr_dqn.yaml
@@ -0,0 +1,9 @@
# ---MLP QR-DQN Networks---
actor_network:
  pre_torso:
    _target_: stoix.networks.torso.MLPTorso
    layer_sizes: [256, 256]
    use_layer_norm: True
    activation: silu
  action_head:
    _target_: stoix.networks.heads.QuantileDiscreteQNetwork
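
A sketch of how this fragment could be turned into Flax modules with hydra.utils.instantiate. The head's action_dim, epsilon, and num_quantiles are not in the YAML, so the assumption here is that Stoix fills them in at build time from the environment spec and the system config; the values below are illustrative, not the repo's actual wiring.

import hydra

# cfg.network is assumed to hold the YAML fragment above (Hydra's default packaging).
torso = hydra.utils.instantiate(cfg.network.actor_network.pre_torso)
head = hydra.utils.instantiate(
    cfg.network.actor_network.action_head,
    action_dim=2,       # assumed: number of discrete actions from the env spec
    epsilon=0.1,        # assumed: system.training_epsilon
    num_quantiles=200,  # assumed: system.num_quantiles
)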
24 changes: 24 additions & 0 deletions stoix/configs/system/ff_qr_dqn.yaml
@@ -0,0 +1,24 @@
# --- Defaults FF-QR-DQN ---

total_timesteps: 1e8 # Total number of environment steps.
# If unspecified, it is derived from num_updates; otherwise, num_updates is derived from this value.
num_updates: ~ # Number of updates.
seed: 42

# --- RL hyperparameters ---
update_batch_size: 1 # Number of vectorised gradient updates per device.
rollout_length: 8 # Number of environment steps per vectorised environment.
epochs: 16 # Number of SGD steps per rollout.
warmup_steps: 128 # Number of steps to collect before training.
buffer_size: 100_000 # Size of the replay buffer.
batch_size: 128 # Number of samples to train on per device.
q_lr: 1e-5 # Learning rate of the Q-network optimizer.
tau: 0.005 # Smoothing coefficient for target networks.
gamma: 0.99 # Discount factor.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
training_epsilon: 0.1 # Epsilon for the epsilon-greedy policy during training.
evaluation_epsilon: 0.00 # Epsilon for the epsilon-greedy policy during evaluation.
max_abs_reward: 1000.0 # Maximum absolute reward value.
huber_loss_parameter: 1.0 # Threshold (kappa) for the quantile Huber loss.
num_quantiles: 200 # Number of quantiles used to approximate the return distribution.
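
For reference, num_quantiles, huber_loss_parameter, and gamma parameterise the quantile-regression Huber loss of QR-DQN (Dabney et al., 2018). The commit's actual loss code lives in the system file not shown above; the following is a minimal, illustrative JAX sketch of that standard loss, with assumed array shapes.

import jax.numpy as jnp

def quantile_huber_loss(
    pred_quantiles: jnp.ndarray,    # [B, N]: quantiles of Q(s, a) for the action taken
    target_quantiles: jnp.ndarray,  # [B, N]: r + gamma * quantiles of Q_target(s', a*)
    kappa: float = 1.0,             # huber_loss_parameter above
) -> jnp.ndarray:
    num_quantiles = pred_quantiles.shape[-1]
    # Midpoint quantile fractions tau_i = (2i - 1) / (2N).
    tau = (jnp.arange(num_quantiles, dtype=jnp.float32) + 0.5) / num_quantiles
    # Pairwise TD errors u[b, i, j] = target_j - pred_i.
    u = target_quantiles[:, None, :] - pred_quantiles[:, :, None]
    abs_u = jnp.abs(u)
    # Huber loss with threshold kappa on each pairwise error.
    huber = jnp.where(abs_u <= kappa, 0.5 * u**2, kappa * (abs_u - 0.5 * kappa))
    # Asymmetric quantile weighting |tau_i - 1{u < 0}|.
    weight = jnp.abs(tau[None, :, None] - (u < 0.0).astype(jnp.float32))
    # Sum over predicted quantiles i, average over target quantiles j and the batch.
    per_sample = jnp.mean(jnp.sum(weight * huber / kappa, axis=1), axis=-1)
    return jnp.mean(per_sample)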
17 changes: 17 additions & 0 deletions stoix/networks/heads.py
@@ -218,3 +218,20 @@ def __call__(
        q_value = jnp.sum(q_dist * atoms, axis=-1)
        atoms = jnp.broadcast_to(atoms, (*q_value.shape, self.num_atoms))
        return q_value, q_logits, atoms


class QuantileDiscreteQNetwork(nn.Module):
    """Discrete Q-network head that outputs a quantile distribution per action (QR-DQN)."""

    action_dim: int
    epsilon: float
    num_quantiles: int
    kernel_init: Initializer = lecun_normal()

    @nn.compact
    def __call__(self, embedding: chex.Array) -> Tuple[distrax.EpsilonGreedy, chex.Array]:
        # Project the torso embedding to num_quantiles values per action.
        q_logits = nn.Dense(self.action_dim * self.num_quantiles, kernel_init=self.kernel_init)(
            embedding
        )
        q_dist = jnp.reshape(q_logits, (-1, self.action_dim, self.num_quantiles))
        # Q-values are the mean over quantiles; stop_gradient so gradients flow only
        # through the quantile distribution, not through action selection.
        q_values = jnp.mean(q_dist, axis=-1)
        q_values = jax.lax.stop_gradient(q_values)
        return distrax.EpsilonGreedy(preferences=q_values, epsilon=self.epsilon), q_dist
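
A short usage sketch of the new head, assuming QuantileDiscreteQNetwork is imported from stoix.networks.heads; the action_dim, epsilon, and embedding size below are illustrative, and in practice the head sits on top of the MLP torso configured above.

import jax
import jax.numpy as jnp

# Illustrative values: action_dim from the env spec, epsilon from system.training_epsilon,
# num_quantiles from system.num_quantiles.
head = QuantileDiscreteQNetwork(action_dim=2, epsilon=0.1, num_quantiles=200)
dummy_embedding = jnp.zeros((1, 256))  # [batch, embedding_dim]: output of the MLP torso
params = head.init(jax.random.PRNGKey(0), dummy_embedding)

eps_greedy_dist, q_dist = head.apply(params, dummy_embedding)
action = eps_greedy_dist.sample(seed=jax.random.PRNGKey(1))  # epsilon-greedy action
# q_dist has shape (1, 2, 200): one return distribution (200 quantiles) per action.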