# Source code for ramsey._src.neural_process.attentive_neural_process
import flax
from chex import assert_axis_dimension
from flax import nnx
from flax.nnx import rnglib
from jax import numpy as jnp
from ramsey._src.family import Family, Gaussian
from ramsey._src.neural_process.neural_process import NP
from ramsey._src.nn.attention.attention import Attention
# Public API of this module: only the attentive neural process is exported.
__all__ = ["ANP"]
# File-wide suppression of Ruff's PLR0913 (too-many-arguments) rule.
# ruff: noqa: PLR0913
# [docs]
class ANP(NP):
    r"""An attentive neural process.

    Implements the core structure of an attentive neural process
    :cite:p:`kim2018attentive`.

    Args:
        decoder: the decoder can be any network, but is typically an MLP. Note
            that the *last* layer of the decoder needs to have twice the
            number of nodes as the data you try to model
        deterministic_encoder: an optional tuple of a ``flax.nnx.Module`` and
            an :class:`Attention` object. The module encodes the concatenated
            context inputs and outputs; the attention object is then applied
            to the context inputs, the encoded context and the target inputs
        latent_encoder: an optional tuple of two ``flax.nnx.Module`` objects.
            The latent encoder can be any network, but is typically an MLP.
            The first element of the tuple is a neural network used before
            the aggregation step, while the second element is a neural
            network used to compute mean(s) and standard deviation(s) of the
            latent Gaussian
        family: distributional family of the response variable
        rngs: optional random number generators, forwarded unchanged to the
            parent :class:`NP` constructor
    """

    def __init__(
        self,
        decoder: nnx.Module,
        deterministic_encoder: tuple[flax.nnx.Module, Attention] | None = None,
        latent_encoder: tuple[flax.nnx.Module, flax.nnx.Module] | None = None,
        family: Family = Gaussian(),
        *,
        rngs: rnglib.Rngs | None = None,
    ):
        super().__init__(
            decoder, deterministic_encoder, latent_encoder, family, rngs=rngs
        )
        # The parent receives the raw tuples; here the individual components
        # are stored under separate attributes. Tuple unpacking (instead of
        # indexing) rejects tuples that do not hold exactly two elements
        # rather than silently ignoring any extra entries.
        if latent_encoder is not None:
            self._latent_encoder, self._latent_variable_encoder = latent_encoder
        if deterministic_encoder is not None:
            (
                self._deterministic_encoder,
                self._deterministic_cross_attention,
            ) = deterministic_encoder

    @staticmethod
    def _concat_and_tile(z_deterministic, z_latent, num_observations):
        """Combine the deterministic and latent representations.

        Args:
            z_deterministic: representation from the deterministic path, or
                ``None`` if there is none. Presumably shaped
                ``(batch, num_observations, features)`` -- TODO confirm.
            z_latent: latent representation, or ``None`` if there is none.
                If its axis 1 has size 1 it is tiled along that axis to
                ``num_observations``.
            num_observations: required size of axis 1 of the result.

        Returns:
            The available representations concatenated along the last axis
            (or the single available one, unchanged).

        Raises:
            ValueError: if both ``z_deterministic`` and ``z_latent`` are
                ``None``.
        """
        if z_latent is not None and z_latent.shape[1] == 1:
            # A latent code with a singleton observation axis is repeated so
            # that it lines up with every target point.
            z_latent = jnp.tile(z_latent, [1, num_observations, 1])
        # The constructor allows either encoder to be absent, so either
        # representation may be None; concatenate only what exists.
        parts = [z for z in (z_deterministic, z_latent) if z is not None]
        if not parts:
            raise ValueError(
                "at least one of 'z_deterministic' and 'z_latent' must be given"
            )
        if len(parts) == 1:
            representation = parts[0]
        else:
            representation = jnp.concatenate(parts, axis=-1)
        assert_axis_dimension(representation, 1, num_observations)
        return representation

    def _encode_deterministic(self, x_context, y_context, x_target):
        """Compute the deterministic representation for the target inputs.

        The context inputs and outputs are concatenated along the last axis
        and passed through the deterministic encoder. The cross-attention
        module is then called with ``x_context``, the encoded context and
        ``x_target`` (presumably as keys, values and queries respectively --
        confirm against :class:`Attention`).
        """
        xy_context = jnp.concatenate([x_context, y_context], axis=-1)
        z_deterministic = self._deterministic_encoder(xy_context)
        z_deterministic = self._deterministic_cross_attention(
            x_context, z_deterministic, x_target
        )
        return z_deterministic