Source code for ramsey._src.family
import abc
import jax
from jax import numpy as jnp
from numpyro import distributions as nd
class Family(abc.ABC):
    """Abstract base class for distributional families.

    A family is a callable object: concrete subclasses implement
    ``__call__`` to turn an array of parameters into a NumPyro
    distribution. The class cannot be instantiated directly because
    ``__call__`` is abstract.
    """

    @abc.abstractmethod
    def __call__(self, target: jnp.ndarray, **kwargs):
        """Build a NumPyro distribution from ``target``.

        Args:
            target: array from which the distribution's parameters are
                derived; how it is interpreted is up to the subclass.
            **kwargs: additional, family-specific keyword arguments.
        """
[docs]
class Gaussian(Family):
    """Gaussian (normal) distributional family."""

    def __call__(
        self, target: jax.Array, log_scale: jax.Array | None = None, **kwargs
    ) -> nd.Distribution:
        """Build a NumPyro ``Normal`` distribution.

        Args:
            target: if ``log_scale`` is given, the location of the
                distribution. Otherwise an array that is split into two
                equal parts along its last axis: the first part is the
                location, the second an unconstrained scale parameter.
            log_scale: optional logarithm of the scale. When provided, the
                scale is ``exp(log_scale)`` and ``target`` is used as the
                location unchanged.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            A ``numpyro.distributions.Normal`` with the derived location
            and scale.
        """
        if log_scale is None:
            # No explicit scale: both parameters are packed along the last
            # axis of ``target``.
            loc, raw_scale = jnp.split(target, 2, axis=-1)
            # softplus maps the raw value to a positive number; the affine
            # transform then shifts the resulting scale up by 0.1.
            sigma = 0.1 + 0.9 * jax.nn.softplus(raw_scale)
            return nd.Normal(loc=loc, scale=sigma)

        return nd.Normal(loc=target, scale=jnp.exp(log_scale))
[docs]
class NegativeBinomial(Family):
    """Negative binomial distributional family (mean/concentration form)."""

    def __call__(
        self,
        target: jax.Array,
        log_concentration: jax.Array | None = None,
        **kwargs,
    ) -> nd.Distribution:
        """Build a NumPyro ``NegativeBinomial2`` distribution.

        Both parameters are handled on the log scale and exponentiated
        before being passed to the distribution.

        Args:
            target: if ``log_concentration`` is given, the logarithm of the
                mean. Otherwise an array that is split into two equal parts
                along its last axis: the first part is the log-mean, the
                second the log-concentration.
            log_concentration: optional logarithm of the concentration.
            **kwargs: accepted for interface compatibility; not used.

        Returns:
            A ``numpyro.distributions.NegativeBinomial2`` with
            ``mean=exp(log-mean)`` and
            ``concentration=exp(log-concentration)``.
        """
        if log_concentration is not None:
            log_mean = target
        else:
            # No explicit concentration: both parameters are packed along
            # the last axis of ``target``.
            log_mean, log_concentration = jnp.split(target, 2, axis=-1)

        return nd.NegativeBinomial2(
            mean=jnp.exp(log_mean),
            concentration=jnp.exp(log_concentration),
        )