Source code for ramsey._src.neural_process.neural_process

import flax
import jax
import numpyro.distributions
import numpyro.distributions as dist
from chex import assert_axis_dimension, assert_rank
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.module import first_from
from jax import numpy as jnp
from numpyro.distributions import kl_divergence

from ramsey._src.family import Family, Gaussian

__all__ = ["NP"]


[docs] class NP(nnx.Module): r"""A neural process. Implements the core structure of a vanilla (latent) neural process :cite:p:`garnelo18conditional,garnelo2018neural`. Args: decoder: the decoder can be any network, but is typically an MLP. Note that the _last_ layer of the decoder needs to have twice the number of nodes as the data you try to model latent_encoder: the latent encoder can be any network, but is typically an MLP. The first element of the tuple is a neural network used before the aggregation step, while the second element of the tuple encodes is a neural network used to compute mean(s) and standard deviation(s) of the latent Gaussian. deterministic_encoder: the deterministic encoder can be any network, but is typically an MLP family: distributional family of the response variable """ def __init__( self, decoder: nnx.Module, deterministic_encoder: flax.nnx.Module | None = None, latent_encoder: tuple[flax.nnx.Module, flax.nnx.Module] | None = None, family: Family = Gaussian(), *, rngs: rnglib.Rngs | None = None, ): self.rngs = rngs self._decoder = decoder self._family = family self._latent_encoder = latent_encoder self._deterministic_encoder = deterministic_encoder if latent_encoder is None and deterministic_encoder is None: raise ValueError("either latent or deterministic encoder needs to be set") if latent_encoder is not None: self._latent_encoder, self._latent_variable_encoder = ( latent_encoder[0], latent_encoder[1], )
[docs] def __call__( self, x_context: jax.Array, y_context: jax.Array, x_target: jax.Array, *, rngs: rnglib.Rngs | None = None, ) -> numpyro.distributions.Distribution: """Transform inputs through the neural process. Args: x_context: context input data of dimension (batch_dim, spatial_dims..., feature_dim) y_context: context output data of dimension (batch_dim, spatial_dims..., response_dim) x_target: target input data of dimension (batch_dim, spatial_dims..., feature_dim) rngs: a rnglib.Rngs object for random seeds Returns: returns the predictive distribution of y_target """ assert_rank([x_context, y_context, x_target], 3) _, num_observations, _ = x_target.shape if self._latent_encoder is not None: rngs = first_from( rngs, self.rngs, error_msg="no 'rngs' argument provided" ) rng = rngs["sample"]() z_latent = self._encode_latent(x_context, y_context).sample(rng) else: z_latent = None z_deterministic = self._encode_deterministic(x_context, y_context, x_target) representation = self._concat_and_tile( z_deterministic, z_latent, num_observations ) pred_fn = self._decode(representation, x_target, y_context) return pred_fn
[docs] def loss( self, x_context: jax.Array, y_context: jax.Array, x_target: jax.Array, y_target: jax.Array, *, rngs: rnglib.Rngs | None = None, ) -> jax.Array: """Compute the loss for a set of input-output pairs. The loss is computed by approximating the marginal likelihood via a lower bound (the ELBO) of which the negative is returned. Args: x_context: context input data of dimension (batch_dim, spatial_dims..., feature_dim) y_context: context output data of dimension (batch_dim, spatial_dims..., response_dim) x_target: target input data of dimension (batch_dim, spatial_dims..., feature_dim) y_target: target output data of dimension (batch_dim, spatial_dims..., response_dim) rngs: a rnglib.Rngs object for random seeds Returns: returns the negative ELBO """ _, num_observations, _ = x_target.shape if self._latent_encoder is not None: rngs = first_from( rngs, self.rngs, error_msg="no 'rngs' argument provided" ) rng = rngs["sample"]() prior = self._encode_latent(x_context, y_context) posterior = self._encode_latent(x_target, y_target) z_latent = posterior.sample(rng) kl = jnp.sum(kl_divergence(posterior, prior), axis=-1) else: z_latent = None kl = 0 z_deterministic = self._encode_deterministic(x_context, y_context, x_target) representation = self._concat_and_tile( z_deterministic, z_latent, num_observations ) pred_fn = self._decode(representation, x_target, y_target) loglik = jnp.sum(pred_fn.log_prob(y_target), axis=1) elbo = jnp.mean(loglik - kl) return -elbo
@staticmethod def _concat_and_tile(z_deterministic, z_latent, num_observations): if z_deterministic is None: representation = z_latent elif z_latent is None: representation = z_deterministic else: representation = jnp.concatenate([z_deterministic, z_latent], axis=-1) assert_axis_dimension(representation, 1, 1) representation = jnp.tile(representation, [1, num_observations, 1]) assert_axis_dimension(representation, 1, num_observations) return representation def _encode_deterministic( self, x_context: jax.Array, y_context: jax.Array, x_target: jax.Array, ): if self._deterministic_encoder is None: return None xy_context = jnp.concatenate([x_context, y_context], axis=-1) z_deterministic = self._deterministic_encoder(xy_context) z_deterministic = jnp.mean(z_deterministic, axis=1, keepdims=True) return z_deterministic def _encode_latent(self, x_context: jax.Array, y_context: jax.Array): xy_context = jnp.concatenate([x_context, y_context], axis=-1) z_latent = self._latent_encoder(xy_context) # type: ignore[operator,misc] return self._encode_latent_gaussian(z_latent) def _encode_latent_gaussian(self, z_latent): z_latent = jnp.mean(z_latent, axis=1, keepdims=True) z_latent = self._latent_variable_encoder(z_latent) mean, sigma = jnp.split(z_latent, 2, axis=-1) sigma = 0.1 + 0.9 * jax.nn.sigmoid(sigma) return dist.Normal(loc=mean, scale=sigma) def _decode( self, representation: jax.Array, x_target: jax.Array, y: jax.Array ): target = jnp.concatenate([representation, x_target], axis=-1) target = self._decoder(target) family = self._family(target) self._check_posterior_predictive_axis(family, x_target, y) return family @staticmethod def _check_posterior_predictive_axis( family: dist.Distribution, x_target: jax.Array, y: jax.Array, # pylint: disable=invalid-name ): assert_axis_dimension(family.mean, 0, x_target.shape[0]) assert_axis_dimension(family.mean, 1, x_target.shape[1]) assert_axis_dimension(family.mean, 2, y.shape[2])