import jax
from flax import nnx
from flax.nnx import rnglib
from flax.typing import Dtype
from jax import numpy as jnp
from ramsey._src.experimental.kernel.base import Kernel
[docs]
class Periodic(Kernel, nnx.Module):
    """Periodic covariance function.

    Stores the length-scale and the amplitude as scalar, log-transformed
    ``nnx.Param`` objects and exponentiates them whenever the kernel is
    evaluated.

    Args:
      period: the period of the periodic kernel
      active_dims: either None or a list of integers selecting the entries
        of the last axis of the data on which the kernel operates. A list is
        used as-is for indexing; any other value is wrapped in
        ``slice(active_dims)``, so None selects every entry
      rho_init: initializer for the log length-scale
      sigma_init: initializer for the log amplitude
      param_dtype: dtype handed to both initializers
      rngs: random number generators; the ``params`` stream is queried once
        per parameter
    """

    def __init__(
        self,
        period,
        active_dims: list | None = None,
        *,
        rho_init: nnx.initializers.Initializer = nnx.initializers.constant(
            jnp.log(1.0)
        ),
        sigma_init: nnx.initializers.Initializer = nnx.initializers.constant(
            jnp.log(1.0)
        ),
        param_dtype: Dtype = jnp.float32,
        rngs: rnglib.Rngs,
    ):
        self.period = period

        # Both parameters are scalars (shape ``()``). The key for rho is
        # drawn before the key for sigma.
        log_rho_value = rho_init(rngs.params(), (), param_dtype)
        log_sigma_value = sigma_init(rngs.params(), (), param_dtype)
        self.log_rho = nnx.Param(log_rho_value)
        self.log_sigma = nnx.Param(log_sigma_value)

        if isinstance(active_dims, list):
            self._active_dims = active_dims
        else:
            self._active_dims = slice(active_dims)

    def __call__(self, x1: jax.Array, x2: jax.Array = None):
        """Evaluate the kernel between ``x1`` and ``x2``.

        Args:
          x1: first set of data points
          x2: second set of data points; ``x1`` is reused when this is None

        Returns:
          the value returned by ``periodic`` for the selected dimensions
        """
        rhs = x1 if x2 is None else x2
        dims = self._active_dims
        # NOTE(review): exp(log_sigma) is passed on unsquared here, while
        # ExponentiatedQuadratic squares it -- confirm this is intended.
        amplitude = jnp.exp(self.log_sigma)
        length_scale = jnp.exp(self.log_rho)
        return periodic(
            x1[..., dims],
            rhs[..., dims],
            self.period,
            amplitude,
            length_scale,
        )
[docs]
class ExponentiatedQuadratic(Kernel, nnx.Module):
    """Exponentiated quadratic covariance function.

    Keeps a scalar log length-scale and a scalar log standard deviation as
    ``nnx.Param`` objects. On evaluation the length-scale is exponentiated
    and the standard deviation is exponentiated and squared.

    Args:
      active_dims: either None or a list of integers selecting the entries
        of the last axis of the data on which the kernel operates. A list is
        used as-is for indexing; any other value is wrapped in
        ``slice(active_dims)``, so None selects every entry
      rho_init: initializer for the log length-scale
      sigma_init: initializer for the log standard deviation
      param_dtype: dtype handed to both initializers
      rngs: random number generators; the ``params`` stream is queried once
        per parameter
    """

    def __init__(
        self,
        active_dims: list | None = None,
        *,
        rho_init: nnx.initializers.Initializer = nnx.initializers.constant(
            jnp.log(1.0)
        ),
        sigma_init: nnx.initializers.Initializer = nnx.initializers.constant(
            jnp.log(1.0)
        ),
        param_dtype: Dtype = jnp.float32,
        rngs: rnglib.Rngs,
    ):
        scalar_shape = ()
        # The key for rho is drawn before the key for sigma.
        rho_key = rngs.params()
        self.log_rho = nnx.Param(rho_init(rho_key, scalar_shape, param_dtype))
        sigma_key = rngs.params()
        self.log_sigma = nnx.Param(
            sigma_init(sigma_key, scalar_shape, param_dtype)
        )

        use_list = isinstance(active_dims, list)
        self._active_dims = active_dims if use_list else slice(active_dims)

    def __call__(self, x1: jax.Array, x2: jax.Array = None):
        """Evaluate the kernel between ``x1`` and ``x2``.

        Args:
          x1: first set of data points
          x2: second set of data points; ``x1`` is reused when this is None

        Returns:
          the value returned by ``exponentiated_quadratic`` for the selected
          dimensions
        """
        if x2 is None:
            x2 = x1
        lhs = x1[..., self._active_dims]
        rhs = x2[..., self._active_dims]
        # The functional form multiplies by its ``sigma`` argument directly,
        # so the squared standard deviation is handed over.
        variance = jnp.square(jnp.exp(self.log_sigma))
        length_scale = jnp.exp(self.log_rho)
        return exponentiated_quadratic(lhs, rhs, variance, length_scale)
# pylint: disable=invalid-name
[docs]
def exponentiated_quadratic(
x1: jax.Array,
x2: jax.Array,
sigma: float,
rho: float | jax.Array,
):
"""Exponentiated-quadratic convariance function.
Args:
x1: (`n x p`)-dimensional set of data points
x2: (`m x p`)-dimensional set of data points
sigma: the standard deviation of the kernel function
rho: the length-scale of the kernel function. Can be a float or a
:math:`p`-dimensional vector if ARD-behaviour is desired
Returns:
returns a (`n x m`)-dimensional kernel matrix
"""
def _exponentiated_quadratic(x, y, sigma, rho):
x_e = jnp.expand_dims(x, 1) / rho
y_e = jnp.expand_dims(y, 0) / rho
d = jnp.sum(jnp.square(x_e - y_e), axis=2)
K = sigma * jnp.exp(-0.5 * d)
return K
return _exponentiated_quadratic(x1, x2, sigma, rho)
# pylint: disable=invalid-name
[docs]
def periodic(x1: jax.Array, x2: jax.Array, period, sigma, rho):
    """Periodic covariance function.

    Evaluates ``sigma * exp(-2 / rho**2 * sin(pi * r / period) ** 2)`` where
    ``r`` is the Euclidean distance between a row of ``x1`` and a row of
    ``x2``.

    Args:
      x1: (`n x p`)-dimensional set of data points
      x2: (`m x p`)-dimensional set of data points
      period: the period
      sigma: factor that multiplies the exponential term. It is used as-is
        and not squared inside this function
      rho: the length-scale of the kernel function. It is applied to the
        distance after the feature axis has been summed, so it acts as one
        shared length-scale; per-dimension (ARD) length-scales are not
        supported by this formula

    Returns:
      returns a (`n x m`)-dimensional Gram matrix
    """
    diff = jnp.expand_dims(x1, 1) - jnp.expand_dims(x2, 0)
    r2 = jnp.sum(diff**2, axis=2)

    # The derivative of sqrt is infinite at zero, and zero distances occur on
    # the whole diagonal whenever x1 and x2 are the same points. Reverse-mode
    # differentiation would then produce 0 * inf = NaN. Evaluating sqrt on a
    # dummy value for those entries and selecting 0 afterwards keeps the
    # forward result unchanged and the gradient finite. Comparing with
    # ``== 0`` (rather than ``> 0``) lets NaN inputs propagate as before.
    at_zero = r2 == 0
    safe_r2 = jnp.where(at_zero, 1.0, r2)
    r = jnp.where(at_zero, 0.0, jnp.sqrt(safe_r2))

    K = sigma * jnp.exp(-2 / rho**2 * jnp.sin(jnp.pi * r / period) ** 2)
    return K