ramsey

Module containing all implemented probabilistic models and training functions.

Models

ramsey.NP

Bases: Module

A neural process.

Implements the core structure of a neural process [1], i.e., an optional deterministic encoder, a latent encoder, and a decoder.

Attributes:

Name Type Description
decoder Module

the decoder can be any network, but is typically an MLP. Note that the last layer of the decoder needs to have twice the number of nodes as the dimensionality of the response you are modelling, since it parameterizes the response's distributional family (typically mean(s) and standard deviation(s))

latent_encoder Optional[Tuple[Module, Module]]

a tuple of two flax.linen.Modules. The latent encoder can be any network, but is typically an MLP. The first element of the tuple is a neural network used before the aggregation step, while the second element is a neural network used to compute the mean(s) and standard deviation(s) of the latent Gaussian.

deterministic_encoder Optional[Module]

the deterministic encoder can be any network, but is typically an MLP

family Family

distributional family of the response variable

References

[1] Garnelo, Marta, et al. "Neural Processes." CoRR. 2018.

__call__(x_context: Array, y_context: Array, x_target: Array, **kwargs)

Transform the inputs through the neural process.

Parameters:

Name Type Description Default
x_context Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
y_context Array

Input data of dimension (*batch_dims, spatial_dims..., response_dims)

required
x_target Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
**kwargs

Keyword arguments can include y_target: jax.Array. If an argument called 'y_target' is provided, the call computes the predictive posterior distribution together with the loss (negative ELBO)

{}

Returns:

Type Description
Union[distribution, Tuple[distribution, float]]

If 'y_target' is provided as a keyword argument, returns a tuple of the predictive distribution and the negative ELBO, which can be used as a loss for optimization. If 'y_target' is not provided, returns the predictive distribution only.
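
For orientation, a minimal construction-and-call sketch. Only the constructor attributes and the __call__ signature come from the documentation above; the MLP class, the layer sizes, the default family, and the 'sample' rng name are illustrative assumptions.

```python
import jax.numpy as jnp
from flax import linen as nn
from jax import random as jr

from ramsey import NP


class MLP(nn.Module):
    """A plain feed-forward network; the sizes used below are illustrative."""

    layers: tuple

    @nn.compact
    def __call__(self, x):
        for dim in self.layers[:-1]:
            x = nn.relu(nn.Dense(dim)(x))
        return nn.Dense(self.layers[-1])(x)


# Decoder: the last layer has twice the response dimensionality
# (univariate response here, hence 2 output nodes).
# Latent encoder: the second module computes the mean(s) and standard
# deviation(s) of the latent Gaussian, hence 2 * 128 output nodes.
# family is left at its default (assumed to be a Gaussian).
neural_process = NP(
    decoder=MLP((128, 128, 2)),
    latent_encoder=(MLP((128, 128)), MLP((128, 256))),
)

x = jnp.linspace(-jnp.pi, jnp.pi, 50).reshape(1, 50, 1)
y = jnp.sin(x)

# NP is a flax.linen.Module, so parameters are initialized explicitly.
# The 'sample' rng name for the latent draw is an assumption.
params = neural_process.init(
    {"params": jr.PRNGKey(0), "sample": jr.PRNGKey(1)},
    x_context=x[:, :20, :],
    y_context=y[:, :20, :],
    x_target=x,
)

# Without y_target: the predictive distribution only.
pred = neural_process.apply(
    params,
    x_context=x[:, :20, :],
    y_context=y[:, :20, :],
    x_target=x,
    rngs={"sample": jr.PRNGKey(2)},
)

# With y_target: a tuple of the predictive distribution and the
# negative ELBO, usable as a training loss.
pred, loss = neural_process.apply(
    params,
    x_context=x[:, :20, :],
    y_context=y[:, :20, :],
    x_target=x,
    y_target=y,
    rngs={"sample": jr.PRNGKey(2)},
)
```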

ramsey.ANP

Bases: NP

An attentive neural process.

Implements the core structure of an attentive neural process [1], i.e., a deterministic encoder, a latent encoder, and a decoder with a cross-attention module.

Attributes:

Name Type Description
decoder Module

the decoder can be any network, but is typically an MLP. Note that the last layer of the decoder needs to have twice the number of nodes as the dimensionality of the response you are modelling, since it parameterizes the response's distributional family (typically mean(s) and standard deviation(s))

latent_encoder Optional[Tuple[Module, Module]]

a tuple of two flax.linen.Modules. The latent encoder can be any network, but is typically an MLP. The first element of the tuple is a neural network used before the aggregation step, while the second element is a neural network used to compute the mean(s) and standard deviation(s) of the latent Gaussian.

deterministic_encoder Optional[Tuple[Module, Attention]]

a tuple of a flax.linen.Module and an Attention object. The module can be any network, but is typically an MLP; the Attention object computes the cross-attention between context and target inputs

family Family

distributional family of the response variable

References

[1] Kim, Hyunjik, et al. "Attentive Neural Processes." International Conference on Learning Representations. 2019.

__call__(x_context: Array, y_context: Array, x_target: Array, **kwargs)

Transform the inputs through the neural process.

Parameters:

Name Type Description Default
x_context Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
y_context Array

Input data of dimension (*batch_dims, spatial_dims..., response_dims)

required
x_target Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
**kwargs

Keyword arguments can include y_target: jax.Array. If an argument called 'y_target' is provided, the call computes the predictive posterior distribution together with the loss (negative ELBO)

{}

Returns:

Type Description
Union[distribution, Tuple[distribution, float]]

If 'y_target' is provided as a keyword argument, returns a tuple of the predictive distribution and the negative ELBO, which can be used as a loss for optimization. If 'y_target' is not provided, returns the predictive distribution only.

ramsey.DANP

Bases: ANP

A doubly-attentive neural process.

Implements the core structure of a 'doubly-attentive' neural process [1], i.e., a deterministic encoder, a latent encoder with self-attention module, and a decoder with both self- and cross-attention modules.

Attributes:

Name Type Description
decoder Module

the decoder can be any network, but is typically an MLP. Note that the last layer of the decoder needs to have twice the number of nodes as the dimensionality of the response you are modelling, since it parameterizes the response's distributional family (typically mean(s) and standard deviation(s))

latent_encoder Tuple[Module, Attention, Module]

a tuple of two flax.linen.Modules and an Attention object. The first and last elements are the usual modules required for a neural process; the Attention object computes self-attention before the aggregation step

deterministic_encoder Tuple[Module, Attention, Attention]

a tuple of a flax.linen.Module and two Attention objects. The first Attention object is used for self-attention, the second one for cross-attention

family Family

distributional family of the response variable

References

[1] Kim, Hyunjik, et al. "Attentive Neural Processes." International Conference on Learning Representations. 2019.

__call__(x_context: Array, y_context: Array, x_target: Array, **kwargs)

Transform the inputs through the neural process.

Parameters:

Name Type Description Default
x_context Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
y_context Array

Input data of dimension (*batch_dims, spatial_dims..., response_dims)

required
x_target Array

Input data of dimension (*batch_dims, spatial_dims..., feature_dims)

required
**kwargs

Keyword arguments can include y_target: jax.Array. If an argument called 'y_target' is provided, the call computes the predictive posterior distribution together with the loss (negative ELBO)

{}

Returns:

Type Description
Union[distribution, Tuple[distribution, float]]

If 'y_target' is provided as a keyword argument, returns a tuple of the predictive distribution and the negative ELBO, which can be used as a loss for optimization. If 'y_target' is not provided, returns the predictive distribution only.
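
The attentive variants are built the same way, with Attention objects slotted into the encoder tuples as documented above. A sketch, assuming an MLP and a MultiHeadAttention module under ramsey.nn; the import path and the constructor arguments are assumptions, so consult the library's nn module for the actual API.

```python
from ramsey import ANP, DANP
from ramsey.nn import MLP, MultiHeadAttention  # import path assumed

dim = 128

# ANP: the deterministic encoder is (Module, Attention); the Attention
# object performs cross-attention between context and target inputs.
anp = ANP(
    decoder=MLP([dim, dim, 2]),
    latent_encoder=(MLP([dim, dim]), MLP([dim, 2 * dim])),
    deterministic_encoder=(
        MLP([dim, dim]),
        MultiHeadAttention(num_heads=4, head_size=16),  # arguments assumed
    ),
)

# DANP: the latent encoder gains a self-attention module
# (Module, Attention, Module), and the deterministic encoder uses
# self- plus cross-attention (Module, Attention, Attention).
danp = DANP(
    decoder=MLP([dim, dim, 2]),
    latent_encoder=(
        MLP([dim, dim]),
        MultiHeadAttention(num_heads=4, head_size=16),
        MLP([dim, 2 * dim]),
    ),
    deterministic_encoder=(
        MLP([dim, dim]),
        MultiHeadAttention(num_heads=4, head_size=16),
        MultiHeadAttention(num_heads=4, head_size=16),
    ),
)
```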

Functions

ramsey.train_neural_process(rng_key: jr.PRNGKey, neural_process: NP, x: Array, y: Array, n_context: Union[int, Tuple[int]], n_target: Union[int, Tuple[int]], batch_size: int, optimizer=optax.adam(0.0003), n_iter=20000, verbose=False)

Train a neural process.

Utility function to train a latent or conditional neural process, i.e., a process belonging to the NP class.

Parameters:

Name Type Description Default
rng_key PRNGKey

a key for seeding random number generators

required
neural_process NP

an object that inherits from NP

required
x Array

array of inputs. Should be a tensor of dimension b × n × p, where b indexes a sequence of batches, e.g., different time series, n indexes the number of observations per batch, e.g., time points, and p indexes the number of features

required
y Array

array of outputs. Should be a tensor of dimension b × n × q, where b and n are the same as for x and q is the number of outputs

required
n_context Union[int, Tuple[int]]

number of context points. If a tuple is given, the number of context points per iteration is sampled from the interval defined by the tuple.

required
n_target Union[int, Tuple[int]]

number of target points. If a tuple is given, the number of target points per iteration is sampled from the interval defined by the tuple. The number of target points includes the number of context points; that means, if n_context=5 and n_target=10, the target set contains 5 more points than the context set but includes the context points, too.

required
batch_size int

number of elements that are sampled for each gradient step, i.e., the number of elements drawn from the first axis of x and y

required
optimizer

an optax optimizer object

adam(0.0003)
n_iter

number of training iterations

20000
verbose

if True, prints training progress

False

Returns:

Type Description
Tuple[dict, Array]

returns a tuple of the trained parameters and the training loss profile
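
A usage sketch with toy sine data, reusing the neural_process constructed in the NP sketch above. The data shapes follow the b × n × p convention from the parameter descriptions, and all arguments match the signature shown; the data itself and the chosen hyperparameters are illustrative.

```python
import jax.numpy as jnp
from jax import random as jr

from ramsey import train_neural_process

key = jr.PRNGKey(0)

# Toy data: b=64 noisy sine curves with n=100 observations and p=q=1.
x = jnp.tile(jnp.linspace(-jnp.pi, jnp.pi, 100), (64, 1))[..., None]
y = jnp.sin(x) + 0.1 * jr.normal(key, x.shape)

params, losses = train_neural_process(
    jr.PRNGKey(1),
    neural_process,     # any model inheriting from NP, e.g., the NP sketched above
    x=x,
    y=y,
    n_context=10,       # fixed number of context points
    n_target=(20, 30),  # sampled per iteration from this interval
    batch_size=16,
    n_iter=5000,
    verbose=True,
)
```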