Neural processes¶
We'll demonstrate training a neural process [2] and an attentive neural process [3] model using a batch of draws from a Gaussian process. First, we load some libraries.
%matplotlib inline
import fybdthemes
import matplotlib.pyplot as plt
from jax import numpy as jnp
from jax import random as jr
from matplotlib import rc
from ramsey import train_neural_process
from ramsey.data import sample_from_gaussian_process
from ramsey.nn import MLP
fybdthemes.set_theme()
rc("font", **{"family": "sans-serif", "sans-serif": ["Arial Narrow"]})
Next we sample some data from a Gaussian process:
rng_key = jr.PRNGKey(1)
sample_key, rng_key = jr.split(rng_key)
data = sample_from_gaussian_process(sample_key, batch_size=10, num_observations=200)
(x_target, y_target), f_target = (data.x, data.y), data.f
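The sampler returns arrays in a (batch, num_observations, dimension) layout: here, 10 curves with 200 one-dimensional observations each. A minimal sketch of that convention with NumPy stand-ins (the actual draws come from sample_from_gaussian_process):

```python
import numpy as np

# Stand-ins with the same shapes the sampler returns:
# 10 curves, 200 observations each, one input and one output dimension.
x_target = np.random.uniform(-1.0, 1.0, size=(10, 200, 1))
f_target = np.sin(3.0 * x_target)                        # noise-free function values
y_target = f_target + 0.1 * np.random.randn(10, 200, 1)  # noisy observations

assert x_target.shape == y_target.shape == f_target.shape == (10, 200, 1)
```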
Visualize the batch of draws:
cols = fybdthemes.discrete_sequential_colors(10)
_, ax = plt.subplots(figsize=(8, 4))
for i in range(x_target.shape[0]):
    x = jnp.squeeze(x_target[i, :, :])
    y = jnp.squeeze(y_target[i, :, :])
    f = jnp.squeeze(f_target[i, :, :])
    idxs = jnp.argsort(x)
    ax.plot(x[idxs], f[idxs], color=cols[i], alpha=0.5)
    ax.scatter(x[idxs], y[idxs], marker="+", color=cols[i], alpha=0.5)
plt.show()
Neural process¶
Having sampled a data set, we can define the model. A neural process model takes a decoder argument, a latent_encoder argument and an optional deterministic_encoder argument.
The decoder is a neural network, typically an MLP, that maps the latent representation to the data space or, more specifically, parameterizes an exponential-family distribution. The decoder needs to respect the dimensionality of the response to be modelled. For instance, if the response is a univariate Gaussian, the last layer needs to have two neurons: one to parameterize the mean and one to parameterize the variance of the Gaussian.
The latent_encoder is a tuple of two neural networks, typically also MLPs. The first maps the data to a latent representation; this representation is averaged and then passed through the second MLP to parameterize a latent Gaussian.
The deterministic_encoder is typically also an MLP that maps the data to a deterministic representation.
Below we create a simple NP model with a decoder and a latent encoder. All networks are MLPs. The decoder has two neurons in its final layer (for the mean and variance of the Gaussian that models the data); the latent encoder's final layer has 2 * 128 nodes to parameterize a 128-dimensional latent Gaussian.
from ramsey import NP
def get_neural_process():
    dim = 128
    np = NP(
        decoder=MLP([dim] * 3 + [2]),
        latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),
    )
    return np
neural_process = get_neural_process()
We use 10 context points and 20 target points per iteration of training:
train_key, rng_key = jr.split(jr.PRNGKey(0))
n_context, n_target = 10, 20
params, _ = train_neural_process(
    train_key,
    neural_process,
    x=x_target,
    y=y_target,
    n_context=n_context,
    n_target=n_target,
    n_iter=10_000,
    batch_size=2,
)
100%|██████████| 10000/10000 [00:43<00:00, 229.88it/s]
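During each of the 10,000 iterations above, the trainer subsamples context and target points from the curves. A hedged sketch of such a split (illustrative only; variable names and the exact sampling scheme are not Ramsey's implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
n_context, n_target = 10, 20
num_observations = 200

# Draw n_context + n_target distinct indices per curve; the context set
# is a subset of the points the model is asked to predict.
chosen = rng.choice(num_observations, size=n_context + n_target, replace=False)
context_idxs, target_idxs = chosen[:n_context], chosen

assert len(set(context_idxs.tolist())) == n_context
assert set(context_idxs.tolist()) <= set(target_idxs.tolist())
```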
That's it! Finally, we pick a set of points from each curve and make some predictions:
sample_key, rng_key = jr.split(rng_key)
sample_idxs = jr.choice(
    sample_key, x_target.shape[1], shape=(n_target + n_context,), replace=False
)
idxs = [0, 2, 5, 7]
_, axes = plt.subplots(figsize=(10, 6), nrows=2, ncols=2)
for idx, ax in zip(idxs, axes.flatten()):
    x = jnp.squeeze(x_target[idx, :, :])
    f = jnp.squeeze(f_target[idx, :, :])
    y = jnp.squeeze(y_target[idx, :, :])
    srt_idxs = jnp.argsort(x)
    ax.plot(x[srt_idxs], f[srt_idxs], color="black", alpha=0.75)
    ax.scatter(
        x[sample_idxs[:n_context]],
        y[sample_idxs[:n_context]],
        color="black",
        marker="+",
        alpha=0.75,
    )
    for _ in range(20):
        sample_rng_key, rng_key = jr.split(rng_key, 2)
        y_star = neural_process.apply(
            variables=params,
            rngs={"sample": sample_rng_key},
            x_context=x[jnp.newaxis, sample_idxs, jnp.newaxis],
            y_context=y[jnp.newaxis, sample_idxs, jnp.newaxis],
            x_target=x_target[[idx], :, :],
        ).mean
        x_star = jnp.squeeze(x_target[[idx], :, :])
        y_star = jnp.squeeze(y_star)
        ax.plot(x_star[srt_idxs], y_star[srt_idxs], color="#960C1A", alpha=0.1)
plt.show()
Attentive neural process¶
An attentive neural process model takes a decoder argument, a latent_encoder argument and a deterministic_encoder argument with an attention module. In contrast to before, the deterministic_encoder is also a tuple in this case: in addition to the network that generates the representation, one is also required to provide a mechanism for cross-attention.
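To see what the cross-attention contributes, here is a single-head NumPy sketch of scaled dot-product attention (illustrative only, not Ramsey's MultiHeadAttention): embedded target inputs act as queries, embedded context inputs as keys, and the per-context deterministic representations as values.

```python
import numpy as np

def cross_attention(q, k, v):
    """q: (n_target, d), k: (n_context, d), v: (n_context, d_v)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over context points
    return weights @ v  # one aggregated representation per target point

q = np.random.randn(20, 8)   # embedded target inputs
k = np.random.randn(10, 8)   # embedded context inputs
v = np.random.randn(10, 16)  # deterministic context representations
r = cross_attention(q, k, v)
assert r.shape == (20, 16)
```

Because each target point attends to the context points closest to it (in embedding space), the aggregated representation is target-specific rather than a single global average.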
Below we create an ANP model with a decoder, a latent encoder and a deterministic encoder with MultiHeadAttention. The encoder itself is again an MLP; the MultiHeadAttention module also takes an MLP that embeds the keys and queries into a higher-dimensional space before attending.
from ramsey import ANP
from ramsey.nn import MultiHeadAttention
def get_attentive_neural_process():
    dim = 128
    np = ANP(
        decoder=MLP([dim] * 3 + [2]),
        latent_encoder=(MLP([dim] * 3), MLP([dim, dim * 2])),
        deterministic_encoder=(
            MLP([dim] * 3),
            MultiHeadAttention(num_heads=4, head_size=16, embedding=MLP([dim] * 2)),
        ),
    )
    return np
attentive_neural_process = get_attentive_neural_process()
After this setup, parameter initialization and training are identical to before.
train_key, rng_key = jr.split(jr.PRNGKey(0))
params, _ = train_neural_process(
    train_key,
    attentive_neural_process,
    x=x_target,
    y=y_target,
    n_context=n_context,
    n_target=n_target,
    n_iter=10_000,
    batch_size=2,
)
100%|██████████| 10000/10000 [00:52<00:00, 191.85it/s]
We again visualize some predictions.
idxs = [0, 2, 5, 7]
_, axes = plt.subplots(figsize=(10, 6), nrows=2, ncols=2)
for idx, ax in zip(idxs, axes.flatten()):
    x = jnp.squeeze(x_target[idx, :, :])
    f = jnp.squeeze(f_target[idx, :, :])
    y = jnp.squeeze(y_target[idx, :, :])
    srt_idxs = jnp.argsort(x)
    ax.plot(x[srt_idxs], f[srt_idxs], color="black", alpha=0.75)
    ax.scatter(
        x[sample_idxs[:n_context]],
        y[sample_idxs[:n_context]],
        color="black",
        marker="+",
        alpha=0.75,
    )
    for _ in range(20):
        sample_rng_key, rng_key = jr.split(rng_key, 2)
        y_star = attentive_neural_process.apply(
            variables=params,
            rngs={"sample": sample_rng_key},
            x_context=x[jnp.newaxis, sample_idxs, jnp.newaxis],
            y_context=y[jnp.newaxis, sample_idxs, jnp.newaxis],
            x_target=x_target[[idx], :, :],
        ).mean
        x_star = jnp.squeeze(x_target[[idx], :, :])
        y_star = jnp.squeeze(y_star)
        ax.plot(x_star[srt_idxs], y_star[srt_idxs], color="#960C1A", alpha=0.1)
plt.show()
We can see that the predictions have significantly less variance than before.
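One way to make "less variance" concrete is to compare the pointwise spread of the 20 sampled mean functions from each model. A sketch with stand-in data (the real comparison would stack the y_star draws collected in each plotting loop; the arrays and helper below are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
# Stand-ins for 20 sampled mean functions over 200 inputs; the ANP draws
# are given a smaller spread to mimic the plots above.
np_samples = rng.normal(0.0, 0.3, size=(20, 200))
anp_samples = rng.normal(0.0, 0.1, size=(20, 200))

def mean_pointwise_std(samples):
    """Average, over inputs, of the standard deviation across draws."""
    return samples.std(axis=0).mean()

assert mean_pointwise_std(anp_samples) < mean_pointwise_std(np_samples)
```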
Session info¶
import session_info
session_info.show(html=False)
-----
fybdthemes 0.1.2
jax 0.4.19
jaxlib 0.4.19
matplotlib 3.8.0
ramsey 0.2.0
seaborn 0.11.2
session_info 1.0.0
-----
IPython 8.16.1
jupyter_client 8.4.0
jupyter_core 5.4.0
jupyterlab 4.0.7
notebook 7.0.6
-----
Python 3.11.6 | packaged by conda-forge | (main, Oct 3 2023, 10:37:07) [Clang 15.0.7 ]
macOS-13.0.1-arm64-arm-64bit
-----
Session information updated at 2023-10-23 13:56
References¶
[1] Garnelo, Marta, et al. "Conditional neural processes." ICML, 2018.
[2] Garnelo, Marta, et al. "Neural processes." ICML Workshop on Theoretical Foundations and Applications of Deep Generative Model, 2018.
[3] Kim, Hyunjik, et al. "Attentive neural processes." ICLR, 2019.