Source code for ramsey._src.data.data

from collections import namedtuple

import jax
import pandas as pd
from jax import numpy as jnp
from jax import random as jr
from numpyro import distributions as dist

from ramsey._src.data.dataset_m4 import M4Dataset
from ramsey._src.experimental.kernel.stationary import exponentiated_quadratic


# pylint: disable=too-many-locals,invalid-name
[docs] def m4_data(interval: str = "hourly", drop_na: bool = True): """Load a data set from the M4 competition. Args: interval: either of "hourly", "daily", "weekly", "monthly", "yearly" drop_na: drop rows that contain NA values Returns: returns a named tuple with outputs (y), inputs (x), and training and testing indexes for the input-output paris """ train, test = M4Dataset().load(interval) df = pd.concat([train, test.reindex(train.index)], axis=1) if drop_na: df = df.dropna() y = df.values y = y.reshape((*y.shape, 1)) x = jnp.arange(y.shape[1]) / train.shape[1] x = jnp.tile(x, [y.shape[0], 1]).reshape((y.shape[0], y.shape[1], 1)) train_idxs = jnp.arange(train.shape[1]) test_idxs = jnp.arange(test.shape[1]) + train.shape[1] return namedtuple("data", ["y", "x", "train_idxs", "test_idxs"])( # type: ignore y, x, train_idxs, test_idxs )
# pylint: disable=too-many-locals,invalid-name
[docs] def sample_from_sine_function( rng_key: jax.Array, batch_size: int = 10, num_observations: int = 100 ): r"""Sample from a noisy sine function. Creates samples from a noisy sine functions. For each batch, chooses a new hyper-parameters configuration. The inputs, `x` of the sine function have dimensionality :math:`b \times n \times 1`, where `b` is the batch size and `n` is the number of observations per batch. The outputs and latent functions realizations have dimension :math:`b \times n \times 1` as well. Args: rng_key: a JAX random key for seeding batch_size: size of batch num_observations: number of observations per batch rho: the lengthscale of the kernel function sigma: the standard deviation of the kernel function Returns: a tuple consisting of outputs (y), inputs (x) and latent GP realization (f) where """ x = jnp.linspace(-jnp.pi, jnp.pi, num_observations).reshape( (num_observations, 1) ) ys = [] fs = [] for _ in range(batch_size): sample_key1, sample_key2, sample_key3, rng_key = jr.split(rng_key, 4) a = 2 * jr.uniform(sample_key1) - 1 b = jr.uniform(sample_key2) - 0.5 f = a * jnp.sin(x - b) y = f + jr.normal(sample_key3, shape=(num_observations, 1)) * 0.10 fs.append(f.reshape((1, num_observations, 1))) ys.append(y.reshape((1, num_observations, 1))) x = jnp.tile(x, [batch_size, 1, 1]) y = jnp.vstack(jnp.array(ys)) f = jnp.vstack(jnp.array(fs)) return namedtuple("data", "y x f")(y, x, f) # type: ignore[call-arg]
# pylint: disable=too-many-locals,invalid-name
[docs] def sample_from_gaussian_process( rng_key: jax.Array, batch_size: int = 10, num_observations: int = 100, rho: float | None = None, sigma: float | None = None, ): r"""Sample from a Gaussian process. Creates samples from a Gaussian process with exponentiated quadratic covariance function. For each batch, chooses a new hyperparameter configuration where `rho`, the kernel lengthscale is drawn from an `InverseGamma(1, 1)` and sigma, the kernel lengthscale, is drawn from an `InverseGamma(5, 5)`. The inputs, `x` of the Gaussian process have dimensionality :math:`b \times n \times 1`, where `b` is the batch size and `n` is the number of observations per batch. The outputs and latent functions realizations have dimension :math:`b \times n \times 1` as well. Args: rng_key: a random key for seeding batch_size: size of batch num_observations: number of observations per batch rho: the lengthscale of the kernel function sigma: the standard deviation of the kernel function Returns: a tuple consisting of outputs (y), inputs (x) and latent GP realization (f) where """ x = jnp.linspace(-jnp.pi, jnp.pi, num_observations).reshape( (num_observations, 1) ) ys = [] fs = [] for _ in range(batch_size): sample_key1, sample_key2, sample_key3, sample_key4, rng_key = jr.split( rng_key, 5 ) if rho is None: rho = dist.InverseGamma(1, 1).sample(sample_key1) if sigma is None: sigma = dist.InverseGamma(5, 5).sample(sample_key2) K = exponentiated_quadratic(x, x, sigma, rho) f = jr.multivariate_normal( sample_key3, mean=jnp.zeros(num_observations), cov=K + jnp.diag(jnp.ones(num_observations)) * 1e-5, ) y = jr.multivariate_normal( sample_key4, mean=f, cov=jnp.eye(num_observations) * 0.05 ) fs.append(f.reshape((1, num_observations, 1))) ys.append(y.reshape((1, num_observations, 1))) x = jnp.tile(x, [batch_size, 1, 1]) y = jnp.vstack(jnp.array(ys)) f = jnp.vstack(jnp.array(fs)) return namedtuple("data", "y x f")(y, x, f) # type: ignore[call-arg]