Source code for ramsey._src.experimental.kernel.non_stationary

import jax
from flax import nnx
from flax.nnx import rnglib
from flax.typing import Dtype
from jax import numpy as jnp

from ramsey._src.experimental.kernel.base import Kernel


class Linear(Kernel, nnx.Module):
    """Linear covariance function.

    Holds three scalar parameters: ``log_sigma_b`` and ``log_sigma_v``, which
    are exponentiated before use, and ``offset``, which is used as is.

    Args:
        active_dims: list of indexes into the last axis of the inputs that
            the kernel acts upon; any non-list value is handed to ``slice``,
            so ``None`` selects every dimension
        sigma_b_init: initializer used to create ``log_sigma_b``
        sigma_v_init: initializer used to create ``log_sigma_v``
        offset_init: initializer used to create ``offset``
        param_dtype: dtype handed to each initializer
        rngs: a random seed generator; its ``params`` stream is drawn from
            once per parameter
    """

    def __init__(
        self,
        active_dims: list | None = None,
        *,
        sigma_b_init: nnx.initializers.Initializer = nnx.initializers.uniform(),
        sigma_v_init: nnx.initializers.Initializer = nnx.initializers.uniform(),
        offset_init: nnx.initializers.Initializer = nnx.initializers.zeros_init(),
        param_dtype: Dtype = jnp.float32,
        rngs: rnglib.Rngs,
    ):
        def _scalar_param(initializer):
            # every parameter is a scalar (shape ``()``) drawn with a fresh key
            return nnx.Param(initializer(rngs.params(), (), param_dtype))

        if isinstance(active_dims, list):
            selected_dims = active_dims
        else:
            selected_dims = slice(active_dims)
        self._active_dims = selected_dims

        # keep this order: each call consumes the next key of the params stream
        self.log_sigma_b = _scalar_param(sigma_b_init)
        self.log_sigma_v = _scalar_param(sigma_v_init)
        self.offset = _scalar_param(offset_init)

    def __call__(self, x1: jax.Array, x2: jax.Array = None):
        """Call the covariance function.

        Args:
            x1: first set of data points
            x2: second set of data points; ``x1`` is used when it is ``None``

        Returns:
            the result of ``linear`` applied to the active dimensions of
            both inputs
        """
        other = x1 if x2 is None else x2
        dims = self._active_dims
        sigma_b = jnp.exp(self.log_sigma_b.value)
        sigma_v = jnp.exp(self.log_sigma_v.value)
        return linear(
            x1[..., dims],
            other[..., dims],
            sigma_b,
            sigma_v,
            self.offset.value,
        )
def linear(
    x1: jax.Array, x2: jax.Array, sigma_b: float, sigma_v: float, offset: float
):
    r"""Linear covariance function.

    Computes :math:`\sigma_v^2 (x_1 - c) (x_2 - c)^T + \sigma_b^2`, where
    :math:`c` is the offset.

    Args:
        x1: :math:`n \times p`-dimensional set of data points; may carry
            leading batch dimensions
        x2: :math:`m \times p`-dimensional set of data points; may carry
            leading batch dimensions that broadcast against those of ``x1``
        sigma_b: standard deviation of the bias term; its square is added to
            every entry of the Gram matrix
        sigma_v: standard deviation that scales the dot product; its square
            multiplies the dot product
        offset: value subtracted from both inputs before the dot product

    Returns:
        returns a :math:`n \times m`-dimensional Gram matrix, preceded by
        the broadcast batch dimensions if the inputs have any
    """
    x1_shifted = x1 - offset
    x2_shifted = x2 - offset
    # Insert the broadcast axes relative to the trailing dimensions so that
    # inputs with leading batch axes work too; for plain 2-D inputs this is
    # exactly expand_dims(x1, 1) * expand_dims(x2, 0) summed over axis 2.
    pairwise = x1_shifted[..., :, None, :] * x2_shifted[..., None, :, :]
    dot = jnp.sum(pairwise, axis=-1)
    return sigma_v**2 * dot + sigma_b**2