👋 Welcome to Ramsey!

Probabilistic deep learning using JAX

Ramsey is a library for probabilistic modelling using JAX, Flax and NumPyro.

Ramsey’s scope covers

  • neural processes (vanilla, attentive, Markovian, convolutional, …),

  • neural Laplace and Fourier operator models,

  • flow matching and denoising diffusion models,

  • etc.

Example

You can, for instance, construct a simple neural process like this:

from flax import nnx

from ramsey import NP
from ramsey.nn import MLP

def get_neural_process(in_features, out_features):
  dim = 128  # width of the hidden layers
  return NP(
    # decoder whose final layer has width out_features * 2
    decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(0)),
    # latent encoder, given as a pair of MLPs
    latent_encoder=(
      MLP(in_features, [dim, dim], rngs=nnx.Rngs(1)),
      MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(2))
    )
  )

neural_process = get_neural_process(1, 1)

The neural process takes a decoder and a pair of latent encoders as arguments. These are typically flax.nnx MLPs, but you can swap them for other modules, for instance CNNs or RNNs. Once the model is defined, you can train it using the ELBO, which the loss method computes from input-output pairs:

from jax import random as jr
from ramsey.data import sample_from_sine_function

key = jr.PRNGKey(0)
data = sample_from_sine_function(key)

x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y
loss = neural_process.loss(
  x_context=x_context,
  y_context=y_context,
  x_target=x_target,
  y_target=y_target
)
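The slicing above uses the first 20 points of every series as the context set. A minimal sketch of drawing a random context subset instead, assuming arrays of shape (batch, n_points, features) as the slicing suggests. The helper split_context_target and the synthetic sine data are hypothetical stand-ins and not part of Ramsey:

```python
import jax.numpy as jnp
from jax import random as jr


def split_context_target(key, x, y, n_context):
    """Use `n_context` randomly chosen points as context and all points as targets."""
    n_points = x.shape[1]
    # Sample point indices without replacement; the same indices are
    # used for every series in the batch.
    idx = jr.choice(key, n_points, shape=(n_context,), replace=False)
    return (x[:, idx, :], y[:, idx, :]), (x, y)


# Synthetic stand-in data: 4 sine curves with 50 points each.
x = jnp.tile(jnp.linspace(-jnp.pi, jnp.pi, 50)[None, :, None], (4, 1, 1))
y = jnp.sin(x)

(x_ctx, y_ctx), (x_tgt, y_tgt) = split_context_target(
    jr.PRNGKey(1), x, y, n_context=20
)
print(x_ctx.shape, x_tgt.shape)
```

The resulting arrays can be passed to the loss in the same way as the sliced ones above.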

Making predictions can be done like this:

pred = neural_process(x_context=x_context, y_context=y_context, x_target=x_target)
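The constructor above gives the decoder a final layer of width out_features * 2 and the second latent encoder one of width dim * 2. A common reason for such doubling, and an assumption here rather than something this page states, is that the network emits a mean and a scale for each output dimension. A minimal sketch of that split in plain JAX; the function split_mean_scale and the min_scale floor are hypothetical and not part of Ramsey:

```python
import jax
import jax.numpy as jnp


def split_mean_scale(raw, min_scale=0.1):
    """Split the last axis in half: the first half is the mean, the second the scale."""
    mean, raw_scale = jnp.split(raw, 2, axis=-1)
    # softplus maps any real number to a positive one; `min_scale` is a
    # hypothetical floor that keeps the scale away from zero.
    scale = min_scale + jax.nn.softplus(raw_scale)
    return mean, scale


# Stand-in for a network output with out_features = 1, i.e. a last axis of size 2.
raw = jnp.zeros((4, 50, 2))
mean, scale = split_mean_scale(raw)
print(mean.shape, scale.shape)
```

Both halves have the same shape, so they can parameterize, for example, a Gaussian over each target point.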

Why Ramsey

Just as the names of other probabilistic languages are inspired by researchers in the field (e.g., Stan, Edward, Turing), Ramsey takes its name from one of my favourite philosophers/mathematicians, Frank Ramsey.

Installation

To install from PyPI, call:

pip install ramsey

To install the latest GitHub <RELEASE>, just call the following on the command line:

pip install git+https://github.com/ramsey-devs/ramsey@<RELEASE>

If you plan to use Ramsey on GPU/TPU, see also the installation instructions for JAX.
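After installing, a quick way to check which accelerators JAX can see is to list its devices. This is a small sketch using only JAX itself; the output depends on your machine and on which JAX build is installed:

```python
import jax

# List the devices JAX can see; a CPU-only install reports only CPU devices.
devices = jax.devices()
print("Default backend:", jax.default_backend())
for device in devices:
    print(device.platform, device.id)
```

If only CPU devices are listed on a machine with a GPU or TPU, the installed JAX build most likely lacks accelerator support.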

Contributing

Contributions in the form of pull requests are more than welcome. A good way to start is to check out issues labelled “good first issue”.

In order to contribute:

  1. clone Ramsey and install the package manager uv,

  2. create a new branch locally via git checkout -b feature/my-new-feature or git checkout -b issue/fixes-bug,

  3. install all dependencies via uv sync --all-extras,

  4. implement your contribution,

  5. test it by calling make format, make lints and make tests on the (Unix) command line,

  6. submit a PR 🙂

License

Ramsey is licensed under the Apache 2.0 License.