Source code for ramsey._src.neural_process.doubly_attentive_neural_process
import flax
import jax
from flax import nnx
from flax.nnx import rnglib
from jax import numpy as jnp
from ramsey._src.family import Family, Gaussian
from ramsey._src.neural_process.attentive_neural_process import ANP
from ramsey._src.nn.attention.attention import Attention
__all__ = ["DANP"]
# ruff: noqa: PLR0913
[docs]
class DANP(ANP):
    r"""A doubly-attentive neural process.

    Implements the core structure of a 'doubly-attentive' neural process
    :cite:p:`kim2018attentive`. Both the latent and the deterministic path
    apply self-attention to the encoded context points. The deterministic
    path additionally applies cross-attention that involves the target
    inputs.

    Args:
        decoder: the decoder can be any network, but is typically an MLP. Note
            that the _last_ layer of the decoder needs to
            have twice the number of nodes as the data you try to model
        deterministic_encoder: a tuple
            ``(encoder, self_attention, cross_attention)`` made of a
            ``flax.nnx.Module`` followed by two ``Attention`` objects. The
            module encodes the concatenated context inputs and outputs. The
            first ``Attention`` object is used for self-attention over those
            encodings. The second one is used for cross-attention towards
            the target inputs
        latent_encoder: a tuple
            ``(encoder, self_attention, latent_variable_encoder)`` made of two
            ``flax.nnx.Module`` objects with an ``Attention`` object in
            between. The first and last elements are the usual modules
            required for a neural process. The attention object computes
            self-attention before the aggregation
        family: distributional family of the response variable
        rngs: random number generators, forwarded to the parent constructor
    """

    def __init__(
        self,
        decoder: nnx.Module,
        deterministic_encoder: tuple[flax.nnx.Module, Attention, Attention]
        | None = None,
        latent_encoder: tuple[flax.nnx.Module, Attention, flax.nnx.Module]
        | None = None,
        family: Family = Gaussian(),
        *,
        rngs: rnglib.Rngs | None = None,
    ):
        """Construct all networks."""
        # NOTE(review): the three-element tuples presumably do not match the
        # parent's annotated parameter types, hence the ``arg-type`` ignores.
        super().__init__(
            decoder,
            deterministic_encoder,  # type: ignore[arg-type]
            latent_encoder,  # type: ignore[arg-type]
            family,
            rngs=rngs,
        )
        # Unpack after ``super().__init__`` so these assignments take
        # precedence over anything the parent constructor derived from the
        # same tuples.
        if latent_encoder is not None:
            (
                self._latent_encoder,
                self._latent_self_attention,
                self._latent_variable_encoder,
            ) = latent_encoder
        if deterministic_encoder is not None:
            (
                self._deterministic_encoder,
                self._deterministic_self_attention,
                self._deterministic_cross_attention,
            ) = deterministic_encoder  # type: ignore[var-annotated]

    def _encode_latent(self, x_context: jax.Array, y_context: jax.Array):
        """Encode the context set along the latent path.

        Args:
            x_context: context inputs
            y_context: context outputs, concatenated with ``x_context`` along
                the last axis before encoding

        Returns:
            the result of the inherited ``_encode_latent_gaussian`` applied to
            the self-attended context encodings
        """
        xy_context = jnp.concatenate([x_context, y_context], axis=-1)
        z_latent = self._latent_encoder(xy_context)  # type: ignore[operator, misc]
        # Self-attention: the encodings are passed as all three inputs.
        z_latent = self._latent_self_attention(z_latent, z_latent, z_latent)
        return self._encode_latent_gaussian(z_latent)

    def _encode_deterministic(
        self,
        x_context: jax.Array,
        y_context: jax.Array,
        x_target: jax.Array,
    ):
        """Encode the context set along the deterministic path.

        Args:
            x_context: context inputs
            y_context: context outputs, concatenated with ``x_context`` along
                the last axis before encoding
            x_target: target inputs, passed as the last argument of the
                cross-attention

        Returns:
            the output of the cross-attention called with ``x_context``, the
            self-attended context encodings and ``x_target`` (presumably
            keys, values and queries respectively -- confirm against
            ``Attention.__call__``)
        """
        xy_context = jnp.concatenate([x_context, y_context], axis=-1)
        z_deterministic = self._deterministic_encoder(xy_context)  # type: ignore[misc]
        # Self-attention over the encoded context points.
        z_deterministic = self._deterministic_self_attention(
            z_deterministic, z_deterministic, z_deterministic
        )
        # Cross-attention that relates the context to the target inputs.
        z_deterministic = self._deterministic_cross_attention(
            x_context, z_deterministic, x_target
        )
        return z_deterministic