ramsey

Module containing all implemented probabilistic models and training functions.

Models

NP(decoder[, deterministic_encoder, ...])

A neural process.

ANP(decoder[, deterministic_encoder, ...])

An attentive neural process.

DANP(decoder[, deterministic_encoder, ...])

A doubly-attentive neural process.

Neural processes

class ramsey.NP(decoder, deterministic_encoder=None, latent_encoder=None, family=<ramsey._src.family.Gaussian object>, *, rngs=None)

A neural process.

Implements the core structure of a vanilla (latent) neural process [1, 2].

Parameters:
  • decoder – the decoder can be any network, but is typically an MLP. Note that the _last_ layer of the decoder needs to have twice as many nodes as the response you are modelling has dimensions

  • latent_encoder – a tuple of two networks, each of which can be any network but is typically an MLP. The first element of the tuple is a neural network applied before the aggregation step, while the second element is a neural network used to compute the mean(s) and standard deviation(s) of the latent Gaussian.

  • deterministic_encoder – the deterministic encoder can be any network, but is typically an MLP

  • family – distributional family of the response variable
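The requirement that the decoder's last layer be twice as wide as the response can be read as the decoder emitting two values per response dimension. The sketch below assumes the default Gaussian family, that the first half of the output is the mean and the second half an unconstrained scale, and that the scale is made positive with a softplus. None of this is confirmed by the reference above, and the helpers `softplus` and `split_decoder_output` are hypothetical names, not part of ramsey.

```python
import math


def softplus(v):
    # Numerically stable log(1 + exp(v)); maps any real number to a positive one.
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def split_decoder_output(raw, response_dim):
    """Split a decoder output of length 2 * response_dim into (mean, scale).

    Assumption: first half is the mean, second half an unconstrained scale.
    """
    if len(raw) != 2 * response_dim:
        raise ValueError("decoder output must have 2 * response_dim entries")
    mean = list(raw[:response_dim])
    scale = [softplus(v) for v in raw[response_dim:]]
    return mean, scale


# A two-dimensional response needs a decoder output with four entries.
mean, scale = split_decoder_output([0.5, -1.0, 0.0, 2.0], response_dim=2)
```

With a one-dimensional response the decoder would end in two nodes, with a three-dimensional response in six, and so on.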

__call__(x_context, y_context, x_target, *, rngs=None)

Transform inputs through the neural process.

Return type:

Distribution

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the predictive distribution of y_target
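To make the shape conventions concrete, the following sketch builds toy inputs as nested lists and checks that they agree with each other. It assumes a single spatial axis (for example, time points) and uses plain Python lists rather than arrays; `shape`, `zeros` and `check_np_inputs` are hypothetical helpers, not part of ramsey.

```python
def shape(a):
    """Shape of a regularly nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(a, list) and a:
        dims.append(len(a))
        a = a[0]
    return tuple(dims)


def zeros(*dims):
    """Nested list of zeros with the given shape."""
    if len(dims) == 1:
        return [0.0] * dims[0]
    return [zeros(*dims[1:]) for _ in range(dims[0])]


def check_np_inputs(x_context, y_context, x_target):
    """Check the documented shape conventions and return the three shapes."""
    xc, yc, xt = shape(x_context), shape(y_context), shape(x_target)
    # Context inputs and outputs share every axis except the trailing one.
    if xc[:-1] != yc[:-1]:
        raise ValueError("x_context and y_context disagree on leading axes")
    # Targets share the batch axis and the feature axis with the context.
    if xt[0] != xc[0] or xt[-1] != xc[-1]:
        raise ValueError("x_target must match batch_dim and feature_dim")
    return xc, yc, xt


# 4 series, 10 context points, 25 target points, 1 feature, 2 responses.
x_context = zeros(4, 10, 1)
y_context = zeros(4, 10, 2)
x_target = zeros(4, 25, 1)
shapes = check_np_inputs(x_context, y_context, x_target)
# shapes == ((4, 10, 1), (4, 10, 2), (4, 25, 1))
```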

loss(x_context, y_context, x_target, y_target, *, rngs=None)

Compute the loss for a set of input-output pairs.

The marginal likelihood is approximated by a lower bound (the ELBO); the negative of this bound is returned as the loss.

Return type:

Array

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_target – target output data of dimension (batch_dim, spatial_dims…, response_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the negative ELBO
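For intuition, the negative ELBO of a latent neural process is commonly written as the expected log-likelihood of the targets minus a KL divergence between the latent distribution conditioned on the targets and the one conditioned on the context. The sketch below evaluates that expression for univariate Gaussians, with a single latent sample assumed to be already folded into the predictive means and scales. Whether ramsey uses exactly this estimator is an assumption, and all function names are hypothetical.

```python
import math


def gaussian_log_pdf(y, mean, scale):
    """Log density of N(mean, scale^2) evaluated at y."""
    z = (y - mean) / scale
    return -0.5 * math.log(2.0 * math.pi) - math.log(scale) - 0.5 * z * z


def kl_gaussian(mean_q, scale_q, mean_p, scale_p):
    """KL(N(mean_q, scale_q^2) || N(mean_p, scale_p^2)) for univariate Gaussians."""
    return (
        math.log(scale_p / scale_q)
        + (scale_q ** 2 + (mean_q - mean_p) ** 2) / (2.0 * scale_p ** 2)
        - 0.5
    )


def negative_elbo(y_target, pred_mean, pred_scale, latent_target, latent_context):
    """Negative ELBO: -(log-likelihood of targets - KL between latent Gaussians).

    latent_target and latent_context are lists of (mean, scale) pairs,
    one pair per latent dimension.
    """
    log_lik = sum(
        gaussian_log_pdf(y, m, s)
        for y, m, s in zip(y_target, pred_mean, pred_scale)
    )
    kl = sum(
        kl_gaussian(mq, sq, mp, sp)
        for (mq, sq), (mp, sp) in zip(latent_target, latent_context)
    )
    return -(log_lik - kl)


# Perfect predictions with unit scale and identical latent distributions:
# the KL term vanishes and only the Gaussian normalising constants remain.
loss = negative_elbo(
    y_target=[0.0, 1.0],
    pred_mean=[0.0, 1.0],
    pred_scale=[1.0, 1.0],
    latent_target=[(0.0, 1.0)],
    latent_context=[(0.0, 1.0)],
)
```

Minimising this quantity both improves the fit to the targets and keeps the latent distribution inferred from the context close to the one inferred from the targets.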

class ramsey.ANP(decoder, deterministic_encoder=None, latent_encoder=None, family=<ramsey._src.family.Gaussian object>, *, rngs=None)

An attentive neural process.

Implements the core structure of an attentive neural process [3].

Parameters:
  • decoder – the decoder can be any network, but is typically an MLP. Note that the _last_ layer of the decoder needs to have twice as many nodes as the response you are modelling has dimensions

  • deterministic_encoder – a tuple of a flax.nnx.Module and an Attention object

  • latent_encoder – an optional tuple of two flax.nnx.Modules. Each can be any network, but is typically an MLP. The first element of the tuple is a neural network applied before the aggregation step, while the second element is a neural network used to compute the mean(s) and standard deviation(s) of the latent Gaussian.

  • family – distributional family of the response variable
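In an attentive neural process, the deterministic path replaces mean aggregation over context points with attention [3]: each target input attends to the context inputs and receives its own weighted combination of context representations. The sketch below shows plain scaled dot-product cross-attention in pure Python. It does not use ramsey's Attention classes, whose interface is not described here, and `softmax` and `cross_attention` are hypothetical helpers.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    """Scaled dot-product attention.

    queries: one vector per target point
    keys:    one vector per context point
    values:  one representation per context point
    Returns one attended representation per target point.
    """
    dim = len(keys[0])
    n_out = len(values[0])
    attended = []
    for q in queries:
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in keys
        ]
        weights = softmax(scores)
        attended.append(
            [sum(w * v[c] for w, v in zip(weights, values)) for c in range(n_out)]
        )
    return attended


# Identical keys give uniform weights, so attention reduces to a plain mean.
uniform = cross_attention([[1.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0], [3.0]])

# A query aligned with the first key puts almost all weight on the first value.
peaked = cross_attention([[10.0, 0.0]], [[10.0, 0.0], [-10.0, 0.0]], [[1.0], [3.0]])
```

The first call illustrates how mean aggregation arises as a special case; the second shows a target point drawing almost exclusively on the most similar context point.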

__call__(x_context, y_context, x_target, *, rngs=None)

Transform inputs through the neural process.

Return type:

Distribution

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the predictive distribution of y_target

loss(x_context, y_context, x_target, y_target, *, rngs=None)

Compute the loss for a set of input-output pairs.

The marginal likelihood is approximated by a lower bound (the ELBO); the negative of this bound is returned as the loss.

Return type:

Array

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_target – target output data of dimension (batch_dim, spatial_dims…, response_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the negative ELBO

class ramsey.DANP(decoder, deterministic_encoder=None, latent_encoder=None, family=<ramsey._src.family.Gaussian object>, *, rngs=None)

A doubly-attentive neural process.

Implements the core structure of a ‘doubly-attentive’ neural process [3].

Parameters:
  • decoder – the decoder can be any network, but is typically an MLP. Note that the _last_ layer of the decoder needs to have twice as many nodes as the response you are modelling has dimensions

  • latent_encoder – a tuple of two flax.nnx.Modules and an attention object. The first and last elements are the usual modules required for a neural process; the attention object computes self-attention before the aggregation step

  • deterministic_encoder – a tuple of a flax.nnx.Module and two Attention objects. The first Attention object is used for self-attention, the second one for cross-attention

  • family – distributional family of the response variable
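The latent encoder tuple described above suggests a three-stage path: encode every context pair, let the encoded points interact through self-attention, then aggregate and map the result to the parameters of the latent Gaussian. The sketch below wires these stages together with toy stand-ins. The mean aggregation, the toy functions and the name `latent_path` are assumptions for illustration and not ramsey's implementation.

```python
import math


def mean_aggregate(representations):
    """Average per-point representations into one permutation-invariant vector."""
    n = len(representations)
    return [sum(column) / n for column in zip(*representations)]


def latent_path(pairs, encode, self_attend, head):
    """Encode each (x, y) pair, apply self-attention, aggregate, map to (mean, std)."""
    reps = [encode(x, y) for x, y in pairs]  # first element of the tuple
    reps = self_attend(reps)                 # attention object
    pooled = mean_aggregate(reps)            # aggregation step
    return head(pooled)                      # last element of the tuple


def toy_encode(x, y):
    # Stand-in for an MLP: simply concatenates input and output.
    return [x, y]


def toy_self_attend(reps):
    # Placeholder that leaves the representations unchanged; a real
    # self-attention layer would mix information across context points.
    return reps


def toy_head(pooled):
    # First half becomes the mean, second half a positive standard deviation.
    half = len(pooled) // 2
    return pooled[:half], [math.exp(v) for v in pooled[half:]]


pairs = [(0.0, 1.0), (2.0, 3.0), (4.0, -1.0)]
mean, std = latent_path(pairs, toy_encode, toy_self_attend, toy_head)
```

Because the aggregation is a mean, reordering the context points leaves the resulting latent parameters unchanged.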

__call__(x_context, y_context, x_target, *, rngs=None)

Transform inputs through the neural process.

Return type:

Distribution

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the predictive distribution of y_target

loss(x_context, y_context, x_target, y_target, *, rngs=None)

Compute the loss for a set of input-output pairs.

The marginal likelihood is approximated by a lower bound (the ELBO); the negative of this bound is returned as the loss.

Return type:

Array

Parameters:
  • x_context – context input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_context – context output data of dimension (batch_dim, spatial_dims…, response_dim)

  • x_target – target input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • y_target – target output data of dimension (batch_dim, spatial_dims…, response_dim)

  • rngs – a rnglib.Rngs object for random seeds

Returns:

returns the negative ELBO

Train functions

train_neural_process(rng_key, ...[, ...])

Train a neural process.

ramsey.train_neural_process(rng_key, neural_process, x, y, n_context, n_target, batch_size, optimizer=(<function chain.<locals>.init_fn>, <function chain.<locals>.update_fn>), n_iter=20000, verbose=False)

Train a neural process.

Utility function to train a latent or conditional neural process, i.e., a process belonging to the NP class.

Parameters:
  • rng_key – a key for seeding random number generators

  • neural_process – an object that inherits from NP

  • x – array of inputs. Should be a tensor of dimension \(b \times n \times p\) where \(b\) indexes a sequence of batches, e.g., different time series, \(n\) indexes the number of observations per batch, e.g., time points, and \(p\) indexes the number of features

  • y – array of outputs. Should be a tensor of dimension \(b \times n \times q\) where \(b\) and \(n\) are the same as for \(x\) and \(q\) is the number of outputs

  • n_context – number of context points. If a tuple is given, the number of context points is sampled at each iteration from the interval defined by the tuple.

  • n_target – number of target points. If a tuple is given, the number of target points is sampled at each iteration from the interval defined by the tuple. The number of target points includes the context points: with n_context=5 and n_target=10, the target set consists of the 5 context points plus 5 additional points.

  • batch_size – number of elements sampled for each gradient step, i.e., number of elements drawn along the first axis of \(x\) and \(y\)

  • optimizer – an optax optimizer object

  • n_iter – number of training iterations

  • verbose – if True, print training progress

Returns:

returns a tuple of trained parameters and training loss profile
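The relationship between n_context and n_target can be illustrated with a small index-sampling sketch. It assumes that a tuple is interpreted as an inclusive (low, high) interval and that context points are a subset of the drawn target points; it uses Python's `random` module rather than JAX keys, and `resolve_count` and `sample_context_target` are hypothetical helpers, not ramsey functions.

```python
import random


def resolve_count(n, rng):
    """Return n itself, or draw an integer from a (low, high) tuple."""
    if isinstance(n, tuple):
        low, high = n
        # Assumption: both ends of the interval are inclusive.
        return rng.randint(low, high)
    return n


def sample_context_target(n_points, n_context, n_target, rng):
    """Draw target indices and use the first n_context of them as context."""
    k_context = resolve_count(n_context, rng)
    k_target = resolve_count(n_target, rng)
    if not 0 < k_context <= k_target <= n_points:
        raise ValueError("need 0 < n_context <= n_target <= n_points")
    # Sampling without replacement gives distinct target indices.
    target_idx = rng.sample(range(n_points), k_target)
    # The context set is contained in the target set.
    context_idx = target_idx[:k_context]
    return context_idx, target_idx


rng = random.Random(0)
context_idx, target_idx = sample_context_target(50, 5, 10, rng)
```

With n_context=5 and n_target=10 this yields 10 target indices, 5 of which double as context indices, matching the description of n_target above.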