chore: change from distrax to tensorflow probability
EdanToledo committed Feb 19, 2024
1 parent 0ae514d commit 93a9e83
Showing 12 changed files with 313 additions and 81 deletions.
2 changes: 2 additions & 0 deletions requirements/requirements.txt
@@ -18,4 +18,6 @@ protobuf~=3.20
 rlax
 tdqm
 tensorboard_logger
+tensorflow
+tensorflow_probability
 xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@main
2 changes: 1 addition & 1 deletion stoix/configs/network/mlp_continuous.yaml
@@ -6,7 +6,7 @@ actor_network:
     use_layer_norm: True
     activation: silu
   action_head:
-    _target_: stoix.networks.heads.TanhMultivariateNormalDiagHead
+    _target_: stoix.networks.heads.NormalTanhDistributionHead
 
 critic_network:
   pre_torso:
2 changes: 1 addition & 1 deletion stoix/configs/network/mlp_sac.yaml
@@ -6,7 +6,7 @@ actor_network:
     use_layer_norm: False
     activation: silu
   action_head:
-    _target_: stoix.networks.heads.TanhMultivariateNormalDiagHead
+    _target_: stoix.networks.heads.NormalTanhDistributionHead
 
 q_network:
   input_layer:
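
Note: both network configs change only the action head's `_target_`; the rest of each YAML file is untouched. As a rough illustration (not part of this commit), a `_target_` entry like the one above is typically resolved by Hydra-style instantiation; the `action_dim` value below is a made-up placeholder that would normally be supplied when the network is built:

import hydra
from omegaconf import OmegaConf

# Hedged sketch: assumes the configs are resolved with hydra.utils.instantiate.
head_cfg = OmegaConf.create(
    {"_target_": "stoix.networks.heads.NormalTanhDistributionHead", "action_dim": 4}
)
action_head = hydra.utils.instantiate(head_cfg)  # -> NormalTanhDistributionHead(action_dim=4)
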
2 changes: 1 addition & 1 deletion stoix/networks/base.py
@@ -123,7 +123,7 @@ def __call__(
         self,
         policy_hidden_state: chex.Array,
         observation_done: RNNObservation,
-    ) -> Tuple[chex.Array, distrax.Categorical]:
+    ) -> Tuple[chex.Array, distrax.DistributionLike]:
         """Forward pass."""
         observation, done = observation_done
 
190 changes: 158 additions & 32 deletions stoix/networks/distributions.py
@@ -1,41 +1,167 @@
-from typing import Sequence, Union
+from typing import Any, Optional, Sequence
 
-import jax
+import chex
 import jax.numpy as jnp
-from chex import Array, PRNGKey
-from distrax import MultivariateNormalDiag
+import numpy as np
+import tensorflow_probability as tf_tfp
+import tensorflow_probability.substrates.jax as tfp
+from tensorflow_probability.substrates.jax.distributions import (
+    Categorical,
+    Distribution,
+    MultivariateNormalDiag,
+    TransformedDistribution,
+)
 
 
-class TanhMultivariateNormalDiag(MultivariateNormalDiag):
-    """TanhMultivariateNormalDiag"""
+class TanhTransformedDistribution(TransformedDistribution):
+    """Distribution followed by tanh."""
 
-    def sample(
-        self, seed: Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = ()
-    ) -> Array:
-        """Sample from the distribution and apply the tanh."""
-        sample = super().sample(seed=seed, sample_shape=sample_shape)
-        return jnp.tanh(sample)
-
-    def sample_unprocessed(
-        self, seed: Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = ()
-    ) -> Array:
-        """Sample from the distribution without applying the tanh."""
-        sample = super().sample(seed=seed, sample_shape=sample_shape)
-        return sample
-
-    def log_prob_of_unprocessed(self, value: Array) -> Array:
-        """Log probability of a value in transformed distribution.
-        Value is the unprocessed value. i.e. the sample before the tanh."""
-        log_prob = super().log_prob(value) - jnp.sum(
-            2.0 * (jnp.log(2.0) - value - jax.nn.softplus(-2.0 * value)),
-            axis=-1,
+    def __init__(
+        self, distribution: Distribution, threshold: float = 0.999, validate_args: bool = False
+    ) -> None:
+        """Initialize the distribution.
+        Args:
+            distribution: The distribution to transform.
+            threshold: Clipping value of the action when computing the logprob.
+            validate_args: Passed to super class.
+        """
+        super().__init__(
+            distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args
         )
+        # Computes the log of the average probability distribution outside the
+        # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for
+        # log_prob_left and [atanh(threshold), inf] for log_prob_right.
+        self._threshold = threshold
+        inverse_threshold = self.bijector.inverse(threshold)
+        # average(pdf) = p/epsilon
+        # So log(average(pdf)) = log(p) - log(epsilon)
+        log_epsilon = jnp.log(1.0 - threshold)
+        # Those 2 values are differentiable w.r.t. model parameters, such that the
+        # gradient is defined everywhere.
+        self._log_prob_left = self.distribution.log_cdf(-inverse_threshold) - log_epsilon
+        self._log_prob_right = (
+            self.distribution.log_survival_function(inverse_threshold) - log_epsilon
+        )
-        return log_prob
 
+    def log_prob(self, event: chex.Array) -> chex.Array:
+        # Without this clip there would be NaNs in the inner tf.where and that
+        # causes issues for some reasons.
+        event = jnp.clip(event, -self._threshold, self._threshold)
+        # The inverse image of {threshold} is the interval [atanh(threshold), inf]
+        # which has a probability of "log_prob_right" under the given distribution.
+        return jnp.where(
+            event <= -self._threshold,
+            self._log_prob_left,
+            jnp.where(event >= self._threshold, self._log_prob_right, super().log_prob(event)),
+        )
+
+    def mode(self) -> chex.Array:
+        return self.bijector.forward(self.distribution.mode())
+
+    def entropy(self, seed: chex.PRNGKey = None) -> chex.Array:
+        # We return an estimation using a single sample of the log_det_jacobian.
+        # We can still do some backpropagation with this estimate.
+        return self.distribution.entropy() + self.bijector.forward_log_det_jacobian(
+            self.distribution.sample(seed=seed), event_ndims=0
+        )
+
+    @classmethod
+    def _parameter_properties(cls, dtype: Optional[Any], num_classes: Any = None) -> Any:
+        td_properties = super()._parameter_properties(dtype, num_classes=num_classes)
+        del td_properties["bijector"]
+        return td_properties
 
 
+class DeterministicNormalDistribution(MultivariateNormalDiag):
+    """Deterministic normal distribution. Always returns the mean."""
+
-class DeterministicDistribution(MultivariateNormalDiag):
     def sample(
-        self, seed: Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = ()
-    ) -> Array:
-        sample = self.loc
-        return sample
+        self,
+        seed: chex.PRNGKey = None,
+        sample_shape: Sequence[int] = (),
+        name: str = "sample",
+    ) -> chex.Array:
+        return self.loc
+
+
+@tf_tfp.experimental.auto_composite_tensor
+class DiscreteValuedTfpDistribution(Categorical):
+    """This is a generalization of a categorical distribution.
+    The support for the DiscreteValued distribution can be any real valued range,
+    whereas the categorical distribution has support [0, n_categories - 1] or
+    [1, n_categories]. This generalization allows us to take the mean of the
+    distribution over its support.
+    """
+
+    def __init__(
+        self,
+        values: chex.Array,
+        logits: Optional[chex.Array] = None,
+        probs: Optional[chex.Array] = None,
+        name: str = "DiscreteValuedDistribution",
+    ):
+        """Initialization.
+        Args:
+            values: Values making up support of the distribution. Should have a shape
+                compatible with logits.
+            logits: An N-D Tensor, N >= 1, representing the log probabilities of a set
+                of Categorical distributions. The first N - 1 dimensions index into a
+                batch of independent distributions and the last dimension indexes into
+                the classes.
+            probs: An N-D Tensor, N >= 1, representing the probabilities of a set of
+                Categorical distributions. The first N - 1 dimensions index into a batch
+                of independent distributions and the last dimension represents a vector
+                of probabilities for each class. Only one of logits or probs should be
+                passed in.
+            name: Name of the distribution object.
+        """
+        parameters = dict(locals())
+        self._values = np.asarray(values)
+
+        if logits is not None:
+            logits = jnp.asarray(logits)
+            chex.assert_shape(logits, (..., *self._values.shape))
+
+        if probs is not None:
+            probs = jnp.asarray(probs)
+            chex.assert_shape(probs, (..., *self._values.shape))
+
+        super().__init__(logits=logits, probs=probs, name=name)
+
+        self._parameters = parameters
+
+    @property
+    def values(self) -> chex.Array:
+        return self._values
+
+    @classmethod
+    def _parameter_properties(cls, dtype: np.dtype, num_classes: Any = None) -> Any:
+        return {
+            "values": tfp.util.ParameterProperties(
+                event_ndims=None, shape_fn=lambda shape: (num_classes,), specifies_shape=True
+            ),
+            "logits": tfp.util.ParameterProperties(event_ndims=1),
+            "probs": tfp.util.ParameterProperties(event_ndims=1, is_preferred=False),
+        }
+
+    def _sample_n(self, key: chex.PRNGKey, n: int) -> chex.Array:
+        indices = super()._sample_n(key=key, n=n)
+        return jnp.take_along_axis(self._values, indices, axis=-1)
+
+    def mean(self) -> chex.Array:
+        """Overrides the Categorical mean by incorporating category values."""
+        return jnp.sum(self.probs_parameter() * self._values, axis=-1)
+
+    def variance(self) -> chex.Array:
+        """Overrides the Categorical variance by incorporating category values."""
+        dist_squared = jnp.square(jnp.expand_dims(self.mean(), -1) - self._values)
+        return jnp.sum(self.probs_parameter() * dist_squared, axis=-1)
+
+    def _event_shape(self) -> chex.Array:
+        return jnp.zeros((), dtype=jnp.int32)
+
+    def _event_shape_tensor(self) -> chex.Array:
+        return []
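
For context on the new API (not part of the commit itself): the head added in this commit wraps a Normal in TanhTransformedDistribution and then in Independent. A minimal usage sketch, assuming TensorFlow Probability's JAX substrate is installed and stoix.networks.distributions is importable:

import jax
import jax.numpy as jnp
from tensorflow_probability.substrates.jax.distributions import Independent, Normal

from stoix.networks.distributions import TanhTransformedDistribution

# Hypothetical 4-dimensional action distribution.
base = Normal(loc=jnp.zeros(4), scale=0.5 * jnp.ones(4))
dist = Independent(TanhTransformedDistribution(base), reinterpreted_batch_ndims=1)

action = dist.sample(seed=jax.random.PRNGKey(0))  # squashed into (-1, 1) by the Tanh bijector
log_prob = dist.log_prob(action)                  # uses the clipped log_prob defined above
greedy_action = dist.mode()                       # tanh of the base distribution's mode
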
106 changes: 95 additions & 11 deletions stoix/networks/heads.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Tuple, Union
+from typing import Optional, Sequence, Tuple, Union
 
 import chex
 import distrax
@@ -7,42 +7,67 @@
 import numpy as np
 from flax import linen as nn
 from flax.linen.initializers import Initializer, lecun_normal, orthogonal
+from tensorflow_probability.substrates.jax.distributions import (
+    Categorical,
+    Independent,
+    MultivariateNormalDiag,
+    Normal,
+)
 
-from stoix.networks.distributions import TanhMultivariateNormalDiag
+from stoix.networks.distributions import (
+    DiscreteValuedTfpDistribution,
+    TanhTransformedDistribution,
+)
 
 
 class CategoricalHead(nn.Module):
     action_dim: Union[int, Sequence[int]]
     kernel_init: Initializer = orthogonal(0.01)
 
     @nn.compact
-    def __call__(self, embedding: chex.Array) -> distrax.Categorical:
+    def __call__(self, embedding: chex.Array) -> Categorical:
 
         logits = nn.Dense(np.prod(self.action_dim), kernel_init=self.kernel_init)(embedding)
 
         if not isinstance(self.action_dim, int):
             logits = logits.reshape(self.action_dim)
 
-        return distrax.Categorical(logits=logits)
+        return Categorical(logits=logits)
 
 
-class TanhMultivariateNormalDiagHead(nn.Module):
+class NormalTanhDistributionHead(nn.Module):
 
     action_dim: int
-    init_scale: float = 0.3
     min_scale: float = 1e-3
     kernel_init: Initializer = orthogonal(0.01)
 
     @nn.compact
-    def __call__(self, embedding: chex.Array) -> TanhMultivariateNormalDiag:
+    def __call__(self, embedding: chex.Array) -> Independent:
 
         loc = nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)
-        scale = jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding))
+        scale = (
+            jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding))
+            + self.min_scale
+        )
+        distribution = Normal(loc=loc, scale=scale)
+
+        return Independent(TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1)
+
+
+class MultivariateNormalDiagHead(nn.Module):
+
+    action_dim: int
+    init_scale: float = 0.3
+    min_scale: float = 1e-6
+    kernel_init: Initializer = orthogonal(0.01)
+
+    @nn.compact
+    def __call__(self, embedding: chex.Array) -> distrax.DistributionLike:
+        loc = nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)
+        scale = jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding))
         scale *= self.init_scale / jax.nn.softplus(0.0)
         scale += self.min_scale
 
-        return TanhMultivariateNormalDiag(loc=loc, scale_diag=scale)
+        return MultivariateNormalDiag(loc=loc, scale_diag=scale)
 
 
 class LinearOutputHead(nn.Module):
@@ -65,6 +90,66 @@ def __call__(self, embedding: chex.Array) -> chex.Array:
         return nn.Dense(1, kernel_init=self.kernel_init)(embedding).squeeze(axis=-1)
 
 
+class CategoricalCriticHead(nn.Module):
+
+    num_bins: int = 601
+    vmax: Optional[float] = None
+    vmin: Optional[float] = None
+    kernel_init: Initializer = orthogonal(1.0)
+
+    @nn.compact
+    def __call__(self, embedding: chex.Array) -> distrax.DistributionLike:
+        vmax = self.vmax if self.vmax is not None else 0.5 * (self.num_bins - 1)
+        vmin = self.vmin if self.vmin is not None else -1.0 * vmax
+
+        output = DiscreteValuedTfpHead(
+            vmin=vmin,
+            vmax=vmax,
+            logits_shape=(1,),
+            num_atoms=self.num_bins,
+            kernel_init=self.kernel_init,
+        )(embedding)
+
+        return output
+
+
+class DiscreteValuedTfpHead(nn.Module):
+    """Represents a parameterized discrete valued distribution.
+    The returned distribution is essentially a `tfd.Categorical` that knows its
+    support and thus can compute the mean value.
+    If vmin and vmax have shape S, this will store the category values as a
+    Tensor of shape (S*, num_atoms).
+    Args:
+        vmin: Minimum of the value range
+        vmax: Maximum of the value range
+        num_atoms: The atom values associated with each bin.
+        logits_shape: The shape of the logits, excluding batch and num_atoms
+            dimensions.
+        kernel_init: The initializer for the dense layer.
+    """
+
+    vmin: float
+    vmax: float
+    num_atoms: int
+    logits_shape: Optional[Sequence[int]] = None
+    kernel_init: Initializer = lecun_normal()
+
+    def setup(self) -> None:
+        self._values = np.linspace(self.vmin, self.vmax, num=self.num_atoms, axis=-1)
+        if not self.logits_shape:
+            logits_shape = ()
+        self._logits_shape = logits_shape + (self.num_atoms,)
+        self._logits_size = np.prod(self._logits_shape)
+
+    def __call__(self, inputs: chex.Array) -> distrax.DistributionLike:
+        net = nn.Dense(self._logits_size, kernel_init=self.kernel_init)
+        logits = net(inputs)
+        logits = logits.reshape(logits.shape[:-1] + self._logits_shape)
+        return DiscreteValuedTfpDistribution(values=self._values, logits=logits)
+
+
 class DiscreteQNetworkHead(nn.Module):
     action_dim: int
     epsilon: float = 0.1
@@ -131,6 +216,5 @@ def __call__(
         q_logits = nn.Dense(self.num_atoms, kernel_init=self.kernel_init)(embedding)
         q_dist = jax.nn.softmax(q_logits)
         q_value = jnp.sum(q_dist * atoms, axis=-1)
-        # q_value = jax.lax.stop_gradient(q_value)
         atoms = jnp.broadcast_to(atoms, (*q_value.shape, self.num_atoms))
         return q_value, q_logits, atoms
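
To make the renamed head concrete, here is a minimal sketch of calling it from Flax (not part of the commit; the batch size and embedding width are illustrative assumptions):

import jax
import jax.numpy as jnp

from stoix.networks.heads import NormalTanhDistributionHead

head = NormalTanhDistributionHead(action_dim=4)
embedding = jnp.ones((8, 128))  # hypothetical batch of torso outputs
params = head.init(jax.random.PRNGKey(0), embedding)

dist = head.apply(params, embedding)                # Independent(TanhTransformedDistribution(Normal))
actions = dist.sample(seed=jax.random.PRNGKey(1))   # shape (8, 4), bounded in (-1, 1)
log_probs = dist.log_prob(actions)                  # shape (8,)
greedy = dist.mode()                                # deterministic action for evaluation
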