Skip to content

Commit

Permalink
Merge pull request #8 from EdanToledo/chore/use_tfp_deterministic_distribution
Browse files Browse the repository at this point in the history

chore: use tfd dist instead of created one
  • Loading branch information
EdanToledo authored Feb 23, 2024
2 parents 9b48287 + 2defcd5 commit 97d7f36
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 146 deletions.
2 changes: 1 addition & 1 deletion stoix/configs/default_ff_td3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@ defaults:
- logger: ff_td3
- arch: anakin
- system: ff_td3
- network: mlp_td3
- network: mlp_ddpg
- env: brax/ant
- _self_
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_d4pg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ actor_network:
use_layer_norm: True
activation: silu
action_head:
_target_: stoix.networks.heads.LinearOutputHead
_target_: stoix.networks.heads.DeterministicHead
post_processor:
_target_: stoix.networks.postprocessors.TanhToSpec
_target_: stoix.networks.postprocessors.ScalePostProcessor

q_network:
input_layer:
Expand Down
4 changes: 2 additions & 2 deletions stoix/configs/network/mlp_ddpg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ actor_network:
use_layer_norm: True
activation: silu
action_head:
_target_: stoix.networks.heads.LinearOutputHead
_target_: stoix.networks.heads.DeterministicHead
post_processor:
_target_: stoix.networks.postprocessors.TanhToSpec
_target_: stoix.networks.postprocessors.ScalePostProcessor

q_network:
input_layer:
Expand Down
22 changes: 0 additions & 22 deletions stoix/configs/network/mlp_td3.yaml

This file was deleted.

15 changes: 1 addition & 14 deletions stoix/networks/distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Optional, Sequence
from typing import Any, Optional

import chex
import jax.numpy as jnp
Expand All @@ -8,7 +8,6 @@
from tensorflow_probability.substrates.jax.distributions import (
Categorical,
Distribution,
MultivariateNormalDiag,
TransformedDistribution,
)

Expand Down Expand Up @@ -73,18 +72,6 @@ def _parameter_properties(cls, dtype: Optional[Any], num_classes: Any = None) ->
return td_properties


class DeterministicNormalDistribution(MultivariateNormalDiag):
"""Deterministic normal distribution. Always returns the mean."""

def sample(
self,
seed: chex.PRNGKey = None,
sample_shape: Sequence[int] = (),
name: str = "sample",
) -> chex.Array:
return self.loc


@tf_tfp.experimental.auto_composite_tensor
class DiscreteValuedTfpDistribution(Categorical):
"""This is a generalization of a categorical distribution.
Expand Down
5 changes: 3 additions & 2 deletions stoix/networks/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flax.linen.initializers import Initializer, lecun_normal, orthogonal
from tensorflow_probability.substrates.jax.distributions import (
Categorical,
Deterministic,
Independent,
MultivariateNormalDiag,
Normal,
Expand Down Expand Up @@ -70,7 +71,7 @@ def __call__(self, embedding: chex.Array) -> distrax.DistributionLike:
return MultivariateNormalDiag(loc=loc, scale_diag=scale)


class LinearOutputHead(nn.Module):
class DeterministicHead(nn.Module):
action_dim: int
kernel_init: Initializer = orthogonal(0.01)

Expand All @@ -79,7 +80,7 @@ def __call__(self, embedding: chex.Array) -> chex.Array:

x = nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)

return x
return Deterministic(x)


class ScalarCriticHead(nn.Module):
Expand Down
79 changes: 54 additions & 25 deletions stoix/networks/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,70 @@
from functools import partial
from typing import Any, Callable, Sequence

import chex
import jax
import jax.numpy as jnp
from flax import linen as nn
from tensorflow_probability.substrates.jax.distributions import Distribution

# Unlike bijectors, postprocessors simply wrap the sample and mode methods of a distribution.

class RescaleToSpec(nn.Module):
minimum: float
maximum: float

@nn.compact
def __call__(self, inputs: chex.Array) -> chex.Array:
scale = self.maximum - self.minimum
offset = self.minimum
inputs = 0.5 * (inputs + 1.0) # [0, 1]
output = inputs * scale + offset # [minimum, maximum]
return output
class PostProcessedDistribution(Distribution):
"""A distribution that applies a postprocessing function to the samples and mode.
This is useful for transforming the output of a distribution to a different space, such as
rescaling the output of a tanh-transformed Normal distribution to a different range. However,
this is not the same as a bijector, which also transforms the density function of the
distribution. This is only useful for transforming the samples and mode of the distribution.
For example, for an algorithm that requires taking the log probability of the samples, the
distribution should be transformed using a bijector, not a postprocessor."""

class TanhToSpec(nn.Module):
minimum: float
maximum: float
def __init__(
self, distribution: Distribution, postprocessor: Callable[[chex.Array], chex.Array]
):
self.distribution = distribution
self.postprocessor = postprocessor

@nn.compact
def __call__(self, inputs: chex.Array) -> chex.Array:
scale = self.maximum - self.minimum
offset = self.minimum
inputs = jax.nn.tanh(inputs) # [-1, 1]
inputs = 0.5 * (inputs + 1.0) # [0, 1]
output = inputs * scale + offset # [minimum, maximum]
return output
def sample(self, seed: chex.PRNGKey, sample_shape: Sequence[int] = ()) -> chex.Array:
return self.postprocessor(self.distribution.sample(seed=seed, sample_shape=sample_shape))

def mode(self) -> chex.Array:
return self.postprocessor(self.distribution.mode())

def __getattr__(self, name: str) -> Any:
if name == "__setstate__":
raise AttributeError(name)
return getattr(self.distribution, name)


def rescale_to_spec(inputs: chex.Array, minimum: float, maximum: float) -> chex.Array:
    """Linearly map values from [-1, 1] onto [minimum, maximum].

    No clipping is applied: the result lies in [minimum, maximum] only when
    `inputs` already lies in [-1, 1].

    Args:
        inputs: values to rescale.
        minimum: value that an input of -1 is mapped to.
        maximum: value that an input of +1 is mapped to.

    Returns:
        The rescaled values.
    """
    scale = maximum - minimum
    offset = minimum
    inputs = 0.5 * (inputs + 1.0)  # [0, 1]
    output = inputs * scale + offset  # [minimum, maximum]
    return output


def clip_to_spec(inputs: chex.Array, minimum: float, maximum: float) -> chex.Array:
    """Clamp `inputs` element-wise to the closed interval [minimum, maximum].

    Args:
        inputs: values to clamp.
        minimum: lower bound passed to `jnp.clip`.
        maximum: upper bound passed to `jnp.clip`.

    Returns:
        The clamped values.
    """
    return jnp.clip(inputs, minimum, maximum)


def tanh_to_spec(inputs: chex.Array, minimum: float, maximum: float) -> chex.Array:
    """Squash `inputs` with tanh, then rescale the result to [minimum, maximum].

    Args:
        inputs: unbounded values (e.g. raw network outputs) to squash.
        minimum: value approached as `inputs` goes to -infinity.
        maximum: value approached as `inputs` goes to +infinity.

    Returns:
        The squashed and rescaled values.
    """
    scale = maximum - minimum
    offset = minimum
    inputs = jax.nn.tanh(inputs)  # [-1, 1]
    inputs = 0.5 * (inputs + 1.0)  # [0, 1]
    output = inputs * scale + offset  # [minimum, maximum]
    return output


class ClipToSpec(nn.Module):
class ScalePostProcessor(nn.Module):
minimum: float
maximum: float
scale_fn: Callable[[chex.Array, float, float], chex.Array]

@nn.compact
def __call__(self, inputs: chex.Array) -> chex.Array:
output = jnp.clip(inputs, self.minimum, self.maximum)
return output
def __call__(self, distribution: Distribution) -> Distribution:
post_processor = partial(self.scale_fn, minimum=self.minimum, maximum=self.maximum)
return PostProcessedDistribution(distribution, post_processor)
38 changes: 12 additions & 26 deletions stoix/systems/ddpg/ff_d4pg.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
import copy
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, Tuple

import rlax

from stoix.networks.distributions import DeterministicNormalDistribution

if TYPE_CHECKING:
from dataclasses import dataclass
else:
from chex import dataclass
from typing import Any, Callable, Dict, Tuple

import chex
import flashbax as fbx
Expand All @@ -18,6 +9,7 @@
import jax
import jax.numpy as jnp
import optax
import rlax
from colorama import Fore, Style
from flashbax.buffers.trajectory_buffer import BufferState
from flax.core.frozen_dict import FrozenDict
Expand All @@ -29,6 +21,7 @@
from stoix.evaluator import evaluator_setup
from stoix.networks.base import CompositeNetwork
from stoix.networks.base import FeedForwardActor as Actor
from stoix.networks.postprocessors import tanh_to_spec
from stoix.systems.ddpg.types import (
ActorAndTarget,
DDPGLearnerState,
Expand Down Expand Up @@ -57,7 +50,7 @@ def get_default_behavior_policy(config: DictConfig, actor_apply_fn: ActorApply)
def behavior_policy(
params: DDPGParams, observation: Observation, key: chex.PRNGKey
) -> chex.Array:
action = actor_apply_fn(params, observation)
action = actor_apply_fn(params, observation).mode()
if config.system.exploration_sigma != 0:
action = rlax.add_gaussian_noise(key, action, config.system.exploration_sigma)
return action
Expand Down Expand Up @@ -189,7 +182,7 @@ def _q_loss_fn(
_, q_logits_tm1, q_atoms_tm1 = q_apply_fn(
q_params, transitions.obs, transitions.action
)
next_action = actor_apply_fn(target_actor_params, transitions.next_obs)
next_action = actor_apply_fn(target_actor_params, transitions.next_obs).mode()
_, q_logits_t, q_atoms_t = q_apply_fn(
target_q_params, transitions.next_obs, next_action
)
Expand Down Expand Up @@ -222,7 +215,7 @@ def _actor_loss_fn(
transitions: Transition,
) -> chex.Array:
o_t = transitions.obs
a_t = actor_apply_fn(actor_params, o_t)
a_t = actor_apply_fn(actor_params, o_t).mode()
q_value, _, _ = q_apply_fn(q_params, o_t, a_t)

actor_loss = -jnp.mean(q_value)
Expand Down Expand Up @@ -360,10 +353,14 @@ def learner_setup(
config.network.actor_network.action_head, action_dim=action_dim
)
action_head_post_processor = hydra.utils.instantiate(
config.network.actor_network.post_processor, minimum=-1.0, maximum=1.0
config.network.actor_network.post_processor,
minimum=-1.0,
maximum=1.0,
scale_fn=tanh_to_spec,
)
actor_action_head = CompositeNetwork([actor_action_head, action_head_post_processor])
actor_network = Actor(torso=actor_torso, action_head=actor_action_head)

q_network_input = hydra.utils.instantiate(config.network.q_network.input_layer)
q_network_torso = hydra.utils.instantiate(config.network.q_network.pre_torso)
q_network_head = hydra.utils.instantiate(
Expand Down Expand Up @@ -493,15 +490,6 @@ def learner_setup(
return learn, actor_network, init_learner_state


@dataclass
class EvalActorWrapper:
actor: Actor

def apply(self, params: FrozenDict, x: Observation) -> DeterministicNormalDistribution:
action = self.actor.apply(params, x)
return DeterministicNormalDistribution(loc=action)


def run_experiment(_config: DictConfig) -> None:
"""Runs experiment."""
config = copy.deepcopy(_config)
Expand All @@ -526,13 +514,11 @@ def run_experiment(_config: DictConfig) -> None:
env, (key, actor_net_key, q_net_key), config
)

eval_actor_network = EvalActorWrapper(actor=actor_network)

# Setup evaluator.
evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup(
eval_env=eval_env,
key_e=key_e,
network=eval_actor_network,
network=actor_network,
params=learner_state.params.actor_params.online,
config=config,
)
Expand Down
Loading

0 comments on commit 97d7f36

Please sign in to comment.