From 93a9e833c446e895d03eaee0f6b00791c34739fa Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 19 Feb 2024 14:48:23 +0000 Subject: [PATCH] chore: change from distrax to tensorflow probability --- requirements/requirements.txt | 2 + stoix/configs/network/mlp_continuous.yaml | 2 +- stoix/configs/network/mlp_sac.yaml | 2 +- stoix/networks/base.py | 2 +- stoix/networks/distributions.py | 190 ++++++++++++++++++---- stoix/networks/heads.py | 106 ++++++++++-- stoix/networks/postprocessors.py | 20 ++- stoix/systems/d4pg/ff_d4pg.py | 6 +- stoix/systems/ppo/ff_ppo_continuous.py | 26 +-- stoix/systems/sac/ff_sac.py | 20 +-- stoix/types.py | 6 +- stoix/utils/loss.py | 12 +- 12 files changed, 313 insertions(+), 81 deletions(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 03263fbf..a7a10d2b 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -18,4 +18,6 @@ protobuf~=3.20 rlax tdqm tensorboard_logger +tensorflow +tensorflow_probability xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@main diff --git a/stoix/configs/network/mlp_continuous.yaml b/stoix/configs/network/mlp_continuous.yaml index e39178fa..bc43d246 100644 --- a/stoix/configs/network/mlp_continuous.yaml +++ b/stoix/configs/network/mlp_continuous.yaml @@ -6,7 +6,7 @@ actor_network: use_layer_norm: True activation: silu action_head: - _target_: stoix.networks.heads.TanhMultivariateNormalDiagHead + _target_: stoix.networks.heads.NormalTanhDistributionHead critic_network: pre_torso: diff --git a/stoix/configs/network/mlp_sac.yaml b/stoix/configs/network/mlp_sac.yaml index a178a66f..22fa8702 100644 --- a/stoix/configs/network/mlp_sac.yaml +++ b/stoix/configs/network/mlp_sac.yaml @@ -6,7 +6,7 @@ actor_network: use_layer_norm: False activation: silu action_head: - _target_: stoix.networks.heads.TanhMultivariateNormalDiagHead + _target_: stoix.networks.heads.NormalTanhDistributionHead q_network: input_layer: diff --git a/stoix/networks/base.py b/stoix/networks/base.py index 00535841..6b93ca34 100644 --- a/stoix/networks/base.py +++ b/stoix/networks/base.py @@ -123,7 +123,7 @@ def __call__( self, policy_hidden_state: chex.Array, observation_done: RNNObservation, - ) -> Tuple[chex.Array, distrax.Categorical]: + ) -> Tuple[chex.Array, distrax.DistributionLike]: """Forward pass.""" observation, done = observation_done diff --git a/stoix/networks/distributions.py b/stoix/networks/distributions.py index e5b64814..dbf775fc 100644 --- a/stoix/networks/distributions.py +++ b/stoix/networks/distributions.py @@ -1,41 +1,167 @@ -from typing import Sequence, Union +from typing import Any, Optional, Sequence -import jax +import chex import jax.numpy as jnp -from chex import Array, PRNGKey -from distrax import MultivariateNormalDiag +import numpy as np +import tensorflow_probability as tf_tfp +import tensorflow_probability.substrates.jax as tfp +from tensorflow_probability.substrates.jax.distributions import ( + Categorical, + Distribution, + MultivariateNormalDiag, + TransformedDistribution, +) -class TanhMultivariateNormalDiag(MultivariateNormalDiag): - """TanhMultivariateNormalDiag""" +class TanhTransformedDistribution(TransformedDistribution): + """Distribution followed by tanh.""" - def sample( - self, seed: Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = () - ) -> Array: - """Sample from the distribution and apply the tanh.""" - sample = super().sample(seed=seed, sample_shape=sample_shape) - return jnp.tanh(sample) - - def sample_unprocessed( - self, seed: 
Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = () - ) -> Array: - """Sample from the distribution without applying the tanh.""" - sample = super().sample(seed=seed, sample_shape=sample_shape) - return sample - - def log_prob_of_unprocessed(self, value: Array) -> Array: - """Log probability of a value in transformed distribution. - Value is the unprocessed value. i.e. the sample before the tanh.""" - log_prob = super().log_prob(value) - jnp.sum( - 2.0 * (jnp.log(2.0) - value - jax.nn.softplus(-2.0 * value)), - axis=-1, + def __init__( + self, distribution: Distribution, threshold: float = 0.999, validate_args: bool = False + ) -> None: + """Initialize the distribution. + + Args: + distribution: The distribution to transform. + threshold: Clipping value of the action when computing the logprob. + validate_args: Passed to super class. + """ + super().__init__( + distribution=distribution, bijector=tfp.bijectors.Tanh(), validate_args=validate_args + ) + # Computes the log of the average probability distribution outside the + # clipping range, i.e. on the interval [-inf, -atanh(threshold)] for + # log_prob_left and [atanh(threshold), inf] for log_prob_right. + self._threshold = threshold + inverse_threshold = self.bijector.inverse(threshold) + # average(pdf) = p/epsilon + # So log(average(pdf)) = log(p) - log(epsilon) + log_epsilon = jnp.log(1.0 - threshold) + # Those 2 values are differentiable w.r.t. model parameters, such that the + # gradient is defined everywhere. + self._log_prob_left = self.distribution.log_cdf(-inverse_threshold) - log_epsilon + self._log_prob_right = ( + self.distribution.log_survival_function(inverse_threshold) - log_epsilon ) - return log_prob + def log_prob(self, event: chex.Array) -> chex.Array: + # Without this clip there would be NaNs in the inner tf.where and that + # causes issues for some reasons. + event = jnp.clip(event, -self._threshold, self._threshold) + # The inverse image of {threshold} is the interval [atanh(threshold), inf] + # which has a probability of "log_prob_right" under the given distribution. + return jnp.where( + event <= -self._threshold, + self._log_prob_left, + jnp.where(event >= self._threshold, self._log_prob_right, super().log_prob(event)), + ) + + def mode(self) -> chex.Array: + return self.bijector.forward(self.distribution.mode()) + + def entropy(self, seed: chex.PRNGKey = None) -> chex.Array: + # We return an estimation using a single sample of the log_det_jacobian. + # We can still do some backpropagation with this estimate. + return self.distribution.entropy() + self.bijector.forward_log_det_jacobian( + self.distribution.sample(seed=seed), event_ndims=0 + ) + + @classmethod + def _parameter_properties(cls, dtype: Optional[Any], num_classes: Any = None) -> Any: + td_properties = super()._parameter_properties(dtype, num_classes=num_classes) + del td_properties["bijector"] + return td_properties + + +class DeterministicNormalDistribution(MultivariateNormalDiag): + """Deterministic normal distribution. Always returns the mean.""" -class DeterministicDistribution(MultivariateNormalDiag): def sample( - self, seed: Union[int, PRNGKey], sample_shape: Union[int, Sequence[int]] = () - ) -> Array: - sample = self.loc - return sample + self, + seed: chex.PRNGKey = None, + sample_shape: Sequence[int] = (), + name: str = "sample", + ) -> chex.Array: + return self.loc + + +@tf_tfp.experimental.auto_composite_tensor +class DiscreteValuedTfpDistribution(Categorical): + """This is a generalization of a categorical distribution. 
+ + The support for the DiscreteValued distribution can be any real valued range, + whereas the categorical distribution has support [0, n_categories - 1] or + [1, n_categories]. This generalization allows us to take the mean of the + distribution over its support. + """ + + def __init__( + self, + values: chex.Array, + logits: Optional[chex.Array] = None, + probs: Optional[chex.Array] = None, + name: str = "DiscreteValuedDistribution", + ): + """Initialization. + + Args: + values: Values making up support of the distribution. Should have a shape + compatible with logits. + logits: An N-D Tensor, N >= 1, representing the log probabilities of a set + of Categorical distributions. The first N - 1 dimensions index into a + batch of independent distributions and the last dimension indexes into + the classes. + probs: An N-D Tensor, N >= 1, representing the probabilities of a set of + Categorical distributions. The first N - 1 dimensions index into a batch + of independent distributions and the last dimension represents a vector + of probabilities for each class. Only one of logits or probs should be + passed in. + name: Name of the distribution object. + """ + parameters = dict(locals()) + self._values = np.asarray(values) + + if logits is not None: + logits = jnp.asarray(logits) + chex.assert_shape(logits, (..., *self._values.shape)) + + if probs is not None: + probs = jnp.asarray(probs) + chex.assert_shape(probs, (..., *self._values.shape)) + + super().__init__(logits=logits, probs=probs, name=name) + + self._parameters = parameters + + @property + def values(self) -> chex.Array: + return self._values + + @classmethod + def _parameter_properties(cls, dtype: np.dtype, num_classes: Any = None) -> Any: + return { + "values": tfp.util.ParameterProperties( + event_ndims=None, shape_fn=lambda shape: (num_classes,), specifies_shape=True + ), + "logits": tfp.util.ParameterProperties(event_ndims=1), + "probs": tfp.util.ParameterProperties(event_ndims=1, is_preferred=False), + } + + def _sample_n(self, key: chex.PRNGKey, n: int) -> chex.Array: + indices = super()._sample_n(key=key, n=n) + return jnp.take_along_axis(self._values, indices, axis=-1) + + def mean(self) -> chex.Array: + """Overrides the Categorical mean by incorporating category values.""" + return jnp.sum(self.probs_parameter() * self._values, axis=-1) + + def variance(self) -> chex.Array: + """Overrides the Categorical variance by incorporating category values.""" + dist_squared = jnp.square(jnp.expand_dims(self.mean(), -1) - self._values) + return jnp.sum(self.probs_parameter() * dist_squared, axis=-1) + + def _event_shape(self) -> chex.Array: + return jnp.zeros((), dtype=jnp.int32) + + def _event_shape_tensor(self) -> chex.Array: + return [] diff --git a/stoix/networks/heads.py b/stoix/networks/heads.py index c3b36000..d5eb4a87 100644 --- a/stoix/networks/heads.py +++ b/stoix/networks/heads.py @@ -1,4 +1,4 @@ -from typing import Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import chex import distrax @@ -7,8 +7,17 @@ import numpy as np from flax import linen as nn from flax.linen.initializers import Initializer, lecun_normal, orthogonal +from tensorflow_probability.substrates.jax.distributions import ( + Categorical, + Independent, + MultivariateNormalDiag, + Normal, +) -from stoix.networks.distributions import TanhMultivariateNormalDiag +from stoix.networks.distributions import ( + DiscreteValuedTfpDistribution, + TanhTransformedDistribution, +) class CategoricalHead(nn.Module): @@ -16,33 +25,49 @@ class 
CategoricalHead(nn.Module): kernel_init: Initializer = orthogonal(0.01) @nn.compact - def __call__(self, embedding: chex.Array) -> distrax.Categorical: + def __call__(self, embedding: chex.Array) -> Categorical: logits = nn.Dense(np.prod(self.action_dim), kernel_init=self.kernel_init)(embedding) if not isinstance(self.action_dim, int): logits = logits.reshape(self.action_dim) - return distrax.Categorical(logits=logits) + return Categorical(logits=logits) -class TanhMultivariateNormalDiagHead(nn.Module): +class NormalTanhDistributionHead(nn.Module): action_dim: int - init_scale: float = 0.3 min_scale: float = 1e-3 kernel_init: Initializer = orthogonal(0.01) @nn.compact - def __call__(self, embedding: chex.Array) -> TanhMultivariateNormalDiag: + def __call__(self, embedding: chex.Array) -> Independent: loc = nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding) - scale = jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)) + scale = ( + jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)) + + self.min_scale + ) + distribution = Normal(loc=loc, scale=scale) + + return Independent(TanhTransformedDistribution(distribution), reinterpreted_batch_ndims=1) + + +class MultivariateNormalDiagHead(nn.Module): + action_dim: int + init_scale: float = 0.3 + min_scale: float = 1e-6 + kernel_init: Initializer = orthogonal(0.01) + + @nn.compact + def __call__(self, embedding: chex.Array) -> distrax.DistributionLike: + loc = nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding) + scale = jax.nn.softplus(nn.Dense(self.action_dim, kernel_init=self.kernel_init)(embedding)) scale *= self.init_scale / jax.nn.softplus(0.0) scale += self.min_scale - - return TanhMultivariateNormalDiag(loc=loc, scale_diag=scale) + return MultivariateNormalDiag(loc=loc, scale_diag=scale) class LinearOutputHead(nn.Module): @@ -65,6 +90,66 @@ def __call__(self, embedding: chex.Array) -> chex.Array: return nn.Dense(1, kernel_init=self.kernel_init)(embedding).squeeze(axis=-1) +class CategoricalCriticHead(nn.Module): + + num_bins: int = 601 + vmax: Optional[float] = None + vmin: Optional[float] = None + kernel_init: Initializer = orthogonal(1.0) + + @nn.compact + def __call__(self, embedding: chex.Array) -> distrax.DistributionLike: + vmax = self.vmax if self.vmax is not None else 0.5 * (self.num_bins - 1) + vmin = self.vmin if self.vmin is not None else -1.0 * vmax + + output = DiscreteValuedTfpHead( + vmin=vmin, + vmax=vmax, + logits_shape=(1,), + num_atoms=self.num_bins, + kernel_init=self.kernel_init, + )(embedding) + + return output + + +class DiscreteValuedTfpHead(nn.Module): + """Represents a parameterized discrete valued distribution. + + The returned distribution is essentially a `tfd.Categorical` that knows its + support and thus can compute the mean value. + If vmin and vmax have shape S, this will store the category values as a + Tensor of shape (S*, num_atoms). + + Args: + vmin: Minimum of the value range + vmax: Maximum of the value range + num_atoms: The atom values associated with each bin. + logits_shape: The shape of the logits, excluding batch and num_atoms + dimensions. + kernel_init: The initializer for the dense layer. 
+ """ + + vmin: float + vmax: float + num_atoms: int + logits_shape: Optional[Sequence[int]] = None + kernel_init: Initializer = lecun_normal() + + def setup(self) -> None: + self._values = np.linspace(self.vmin, self.vmax, num=self.num_atoms, axis=-1) + if not self.logits_shape: + logits_shape = () + self._logits_shape = logits_shape + (self.num_atoms,) + self._logits_size = np.prod(self._logits_shape) + + def __call__(self, inputs: chex.Array) -> distrax.DistributionLike: + net = nn.Dense(self._logits_size, kernel_init=self.kernel_init) + logits = net(inputs) + logits = logits.reshape(logits.shape[:-1] + self._logits_shape) + return DiscreteValuedTfpDistribution(values=self._values, logits=logits) + + class DiscreteQNetworkHead(nn.Module): action_dim: int epsilon: float = 0.1 @@ -131,6 +216,5 @@ def __call__( q_logits = nn.Dense(self.num_atoms, kernel_init=self.kernel_init)(embedding) q_dist = jax.nn.softmax(q_logits) q_value = jnp.sum(q_dist * atoms, axis=-1) - # q_value = jax.lax.stop_gradient(q_value) atoms = jnp.broadcast_to(atoms, (*q_value.shape, self.num_atoms)) return q_value, q_logits, atoms diff --git a/stoix/networks/postprocessors.py b/stoix/networks/postprocessors.py index 5a9dc4a7..d9112123 100644 --- a/stoix/networks/postprocessors.py +++ b/stoix/networks/postprocessors.py @@ -1,13 +1,28 @@ +import chex import jax import jax.numpy as jnp from flax import linen as nn +class RescaleToSpec(nn.Module): + minimum: float + maximum: float + + @nn.compact + def __call__(self, inputs: chex.Array) -> chex.Array: + scale = self.maximum - self.minimum + offset = self.minimum + inputs = 0.5 * (inputs + 1.0) # [0, 1] + output = inputs * scale + offset # [minimum, maximum] + return output + + class TanhToSpec(nn.Module): minimum: float maximum: float - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + @nn.compact + def __call__(self, inputs: chex.Array) -> chex.Array: scale = self.maximum - self.minimum offset = self.minimum inputs = jax.nn.tanh(inputs) # [-1, 1] @@ -20,6 +35,7 @@ class ClipToSpec(nn.Module): minimum: float maximum: float - def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: + @nn.compact + def __call__(self, inputs: chex.Array) -> chex.Array: output = jnp.clip(inputs, self.minimum, self.maximum) return output diff --git a/stoix/systems/d4pg/ff_d4pg.py b/stoix/systems/d4pg/ff_d4pg.py index 9e26123a..53aa5d08 100644 --- a/stoix/systems/d4pg/ff_d4pg.py +++ b/stoix/systems/d4pg/ff_d4pg.py @@ -4,7 +4,7 @@ import rlax -from stoix.networks.distributions import DeterministicDistribution +from stoix.networks.distributions import DeterministicNormalDistribution if TYPE_CHECKING: from dataclasses import dataclass @@ -497,9 +497,9 @@ def learner_setup( class EvalActorWrapper: actor: Actor - def apply(self, params: FrozenDict, x: Observation) -> DeterministicDistribution: + def apply(self, params: FrozenDict, x: Observation) -> DeterministicNormalDistribution: action = self.actor.apply(params, x) - return DeterministicDistribution(loc=action) + return DeterministicNormalDistribution(loc=action) def run_experiment(_config: DictConfig) -> None: diff --git a/stoix/systems/ppo/ff_ppo_continuous.py b/stoix/systems/ppo/ff_ppo_continuous.py index 35e38590..f3156d80 100644 --- a/stoix/systems/ppo/ff_ppo_continuous.py +++ b/stoix/systems/ppo/ff_ppo_continuous.py @@ -76,11 +76,11 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) 
value = critic_apply_fn(params.critic_params, last_timestep.observation) - action = actor_policy.sample_unprocessed(seed=policy_key) - log_prob = actor_policy.log_prob_of_unprocessed(action) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) # STEP ENVIRONMENT - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, jnp.tanh(action)) + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS done = timestep.last().reshape(-1) @@ -114,7 +114,7 @@ def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: """Update the network for a single minibatch.""" # UNPACK TRAIN STATE AND BATCH INFO - params, opt_states = train_state + params, opt_states, key = train_state traj_batch, advantages, targets = batch_info def _actor_loss_fn( @@ -122,11 +122,12 @@ def _actor_loss_fn( actor_opt_state: OptState, traj_batch: PPOTransition, gae: chex.Array, + rng_key: chex.PRNGKey, ) -> Tuple: """Calculate the actor loss.""" # RERUN NETWORK actor_policy = actor_apply_fn(actor_params, traj_batch.obs) - log_prob = actor_policy.log_prob_of_unprocessed(traj_batch.action) + log_prob = actor_policy.log_prob(traj_batch.action) # CALCULATE ACTOR LOSS ratio = jnp.exp(log_prob - traj_batch.log_prob) @@ -142,7 +143,7 @@ def _actor_loss_fn( ) loss_actor = -jnp.minimum(loss_actor1, loss_actor2) loss_actor = loss_actor.mean() - entropy = actor_policy.entropy().mean() + entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = loss_actor - config.system.ent_coef * entropy return total_loss_actor, (loss_actor, entropy) @@ -169,9 +170,14 @@ def _critic_loss_fn( return critic_total_loss, (value_loss) # CALCULATE ACTOR LOSS + key, actor_loss_key = jax.random.split(key) actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) actor_loss_info, actor_grads = actor_grad_fn( - params.actor_params, opt_states.actor_opt_state, traj_batch, advantages + params.actor_params, + opt_states.actor_opt_state, + traj_batch, + advantages, + actor_loss_key, ) # CALCULATE CRITIC LOSS @@ -227,7 +233,7 @@ def _critic_loss_fn( "actor_loss": actor_loss, "entropy": entropy, } - return (new_params, new_opt_state), loss_info + return (new_params, new_opt_state, key), loss_info params, opt_states, traj_batch, advantages, targets, key = update_state key, shuffle_key = jax.random.split(key) @@ -246,8 +252,8 @@ def _critic_loss_fn( ) # UPDATE MINIBATCHES - (params, opt_states), loss_info = jax.lax.scan( - _update_minibatch, (params, opt_states), minibatches + (params, opt_states, key), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, key), minibatches ) update_state = (params, opt_states, traj_batch, advantages, targets, key) diff --git a/stoix/systems/sac/ff_sac.py b/stoix/systems/sac/ff_sac.py index 52a0e2df..9f104d1f 100644 --- a/stoix/systems/sac/ff_sac.py +++ b/stoix/systems/sac/ff_sac.py @@ -56,10 +56,10 @@ def _env_step( # SELECT ACTION key, policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) - action = actor_policy.sample_unprocessed(seed=policy_key) + action = actor_policy.sample(seed=policy_key) # STEP ENVIRONMENT - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, jnp.tanh(action)) + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS done = timestep.last().reshape(-1) @@ -112,10 +112,10 @@ def _env_step(learner_state: SACLearnerState, _: Any) -> Tuple[SACLearnerState, key, 
policy_key = jax.random.split(key) actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) - action = actor_policy.sample_unprocessed(seed=policy_key) + action = actor_policy.sample(seed=policy_key) # STEP ENVIRONMENT - env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, jnp.tanh(action)) + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) # LOG EPISODE METRICS done = timestep.last().reshape(-1) @@ -152,8 +152,8 @@ def _alpha_loss_fn( ) -> jnp.ndarray: """Eq 18 from https://arxiv.org/pdf/1812.05905.pdf.""" actor_policy = actor_apply_fn(actor_params, transitions.obs) - action = actor_policy.sample_unprocessed(seed=key) - log_prob = actor_policy.log_prob_of_unprocessed(action) + action = actor_policy.sample(seed=key) + log_prob = actor_policy.log_prob(action) alpha = jnp.exp(log_alpha) alpha_loss = alpha * jax.lax.stop_gradient(-log_prob - config.system.target_entropy) @@ -173,8 +173,8 @@ def _q_loss_fn( ) -> jnp.ndarray: q_old_action = q_apply_fn(q_params, transitions.obs, transitions.action) next_actor_policy = actor_apply_fn(actor_params, transitions.next_obs) - next_action = next_actor_policy.sample_unprocessed(seed=key) - next_log_prob = next_actor_policy.log_prob_of_unprocessed(next_action) + next_action = next_actor_policy.sample(seed=key) + next_log_prob = next_actor_policy.log_prob(next_action) next_q = q_apply_fn(target_q_params, transitions.next_obs, next_action) next_v = jnp.min(next_q, axis=-1) - alpha * next_log_prob target_q = jax.lax.stop_gradient( @@ -197,8 +197,8 @@ def _actor_loss_fn( key: chex.PRNGKey, ) -> chex.Array: actor_policy = actor_apply_fn(actor_params, transitions.obs) - action = actor_policy.sample_unprocessed(seed=key) - log_prob = actor_policy.log_prob_of_unprocessed(action) + action = actor_policy.sample(seed=key) + log_prob = actor_policy.log_prob(action) q_action = q_apply_fn(q_params, transitions.obs, action) min_q = jnp.min(q_action, axis=-1) actor_loss = alpha * log_prob - min_q diff --git a/stoix/types.py b/stoix/types.py index 530a8175..4e129a93 100644 --- a/stoix/types.py +++ b/stoix/types.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Tuple, TypeVar import chex -from distrax import Distribution +from distrax import DistributionLike from flax.core.frozen_dict import FrozenDict from jumanji.types import TimeStep from typing_extensions import NamedTuple, TypeAlias @@ -96,9 +96,9 @@ class ExperimentOutput(NamedTuple, Generic[StoixState]): LearnerFn = Callable[[StoixState], ExperimentOutput[StoixState]] EvalFn = Callable[[FrozenDict, chex.PRNGKey], ExperimentOutput[StoixState]] -ActorApply = Callable[[FrozenDict, Observation], Distribution] +ActorApply = Callable[[FrozenDict, Observation], DistributionLike] CriticApply = Callable[[FrozenDict, Observation], Value] RecActorApply = Callable[ - [FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Distribution] + [FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, DistributionLike] ] RecCriticApply = Callable[[FrozenDict, HiddenState, RNNObservation], Tuple[HiddenState, Value]] diff --git a/stoix/utils/loss.py b/stoix/utils/loss.py index 28bf04d2..ede0f1e7 100644 --- a/stoix/utils/loss.py +++ b/stoix/utils/loss.py @@ -1,8 +1,10 @@ import chex -import distrax import jax import jax.numpy as jnp import rlax +import tensorflow_probability.substrates.jax as tfp + +tfd = tfp.distributions def ppo_loss( @@ -57,9 +59,7 @@ def categorical_double_q_learning( target = 
jax.vmap(rlax.categorical_l2_project)(target_z, p_target_z, q_atoms_tm1) # Compute loss (i.e. temporal difference error). logit_qa_tm1 = q_logits_tm1[batch_indices, a_tm1] - td_error = distrax.Categorical(probs=target).cross_entropy( - distrax.Categorical(logits=logit_qa_tm1) - ) + td_error = tfd.Categorical(probs=target).cross_entropy(tfd.Categorical(logits=logit_qa_tm1)) q_loss = jnp.mean(td_error) return q_loss @@ -103,8 +103,6 @@ def categorical_td_learning( # Project using the Cramer distance and maybe stop gradient flow to targets. target = jax.vmap(rlax.categorical_l2_project)(target_z, v_t_probs, v_atoms_tm1) - td_error = distrax.Categorical(probs=target).cross_entropy( - distrax.Categorical(logits=v_logits_tm1) - ) + td_error = tfd.Categorical(probs=target).cross_entropy(tfd.Categorical(logits=v_logits_tm1)) return jnp.mean(td_error)
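
Note (not part of the patch): the migration above changes the actor-policy API surface in a few places. Actions sampled from NormalTanhDistributionHead are already squashed to [-1, 1] (hence the removed jnp.tanh calls in the PPO and SAC env steps), log_prob replaces log_prob_of_unprocessed, and entropy now takes a seed because TanhTransformedDistribution estimates it from a single sample of the log-det-Jacobian. Below is a minimal sketch of how the new head is consumed, mirroring those call sites; the batch size, embedding size, and variable names are illustrative assumptions, not part of the change.

import jax
import jax.numpy as jnp

from stoix.networks.heads import NormalTanhDistributionHead

key = jax.random.PRNGKey(0)
init_key, sample_key, entropy_key = jax.random.split(key, 3)

head = NormalTanhDistributionHead(action_dim=4)
embedding = jnp.zeros((8, 64))  # (batch, embedding_dim) - purely illustrative shapes

params = head.init(init_key, embedding)
# apply() returns tfd.Independent(TanhTransformedDistribution(Normal(loc, scale)))
policy = head.apply(params, embedding)

action = policy.sample(seed=sample_key)      # already in [-1, 1]; no external jnp.tanh needed
log_prob = policy.log_prob(action)           # replaces log_prob_of_unprocessed(action)
entropy = policy.entropy(seed=entropy_key)   # seed passed, as in the PPO actor loss above

The other new distributions follow the same TFP Distribution interface: DeterministicNormalDistribution lets the D4PG evaluation wrapper keep calling sample() while always returning the mean, and DiscreteValuedTfpDistribution gives the categorical critic heads a mean()/variance() over the atom values rather than over class indices.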