
ramsey.nn

Neural networks and other modules for building neural processes, Bayesian neural networks, etc.

Modules

ramsey.nn.MLP

Bases: Module

A multi-layer perceptron.

Attributes:

- output_sizes (Iterable[int]): number of hidden nodes per layer
- dropout (Optional[float]): dropout rate to apply after each hidden layer
- kernel_init (Initializer): initializer for the weights of the hidden layers
- bias_init (Initializer): initializer for the biases of the hidden layers
- use_bias (bool): whether the hidden layers should use bias nodes
- activation (Callable): activation function to apply after each hidden layer; default is relu
- activate_final (bool): if true, also activate the last layer

Attribute defaults:

- activate_final: bool = False
- activation: Callable = jax.nn.relu
- bias_init: initializers.Initializer = initializers.zeros_init()
- dropout: Optional[float] = None
- kernel_init: initializers.Initializer = default_kernel_init
- output_sizes: Iterable[int] (required, no default)
- use_bias: bool = True

__call__(inputs: Array, is_training: bool = False)

Transform the inputs through the MLP.

Parameters:

- inputs (Array, required): input data of dimension (*batch_dims, spatial_dims..., feature_dims)
- is_training (bool, default False): if true, uses training mode (i.e., applies dropout)

Returns:

- Array: the transformed inputs

setup()

Construct all networks.
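The transformation applied by MLP.__call__ can be sketched outside the library. The following is an illustrative NumPy re-implementation, not Ramsey's code: the function name mlp_forward and the explicit weight/bias lists are assumptions made for the sketch, and dropout is omitted since it is only active when is_training is true.

```python
import numpy as np

def mlp_forward(inputs, weights, biases,
                activation=lambda x: np.maximum(x, 0.0),  # relu, the documented default
                activate_final=False):
    """Apply a stack of dense layers, mirroring what MLP.__call__ computes.

    weights[i] has shape (in_dim, output_sizes[i]); dropout is omitted here.
    """
    h = inputs
    n_layers = len(weights)
    for i, (w, b) in enumerate(zip(weights, biases)):
        h = h @ w + b
        # activate every hidden layer; the last one only if activate_final
        if i < n_layers - 1 or activate_final:
            h = activation(h)
    return h

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3))                   # (batch_dims, feature_dims)
output_sizes = [16, 2]                        # hidden nodes per layer
dims = [x.shape[-1]] + output_sizes
weights = [0.1 * rng.normal(size=(a, b)) for a, b in zip(dims[:-1], dims[1:])]
biases = [np.zeros(b) for b in output_sizes]  # bias_init = zeros
y = mlp_forward(x, weights, biases)
print(y.shape)                                # (8, 2)
```

With activate_final left at its default of False, the last layer's output is returned un-activated, which is why the final layer can produce unconstrained real values.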

ramsey.nn.MultiHeadAttention

Bases: Attention

Multi-head attention.

As described in [1].

Attributes:

- num_heads (int): number of heads
- head_size (int): size of the heads for keys, values, and queries
- embedding (Module): neural network module to embed keys and queries before attention

References

[1] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.

Attribute types:

- embedding: Optional[nn.Module]
- head_size: int
- num_heads: int

__call__(key: Array, value: Array, query: Array) -> Array

Apply attention to the query.

Arguments:

- key (jax.Array): the key
- value (jax.Array): the value
- query (jax.Array): the query

Returns:

- Array: the attended query

setup() -> None

Construct the networks.
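The attention computed by __call__ can be sketched in plain NumPy. This is an illustrative re-implementation of scaled dot-product multi-head attention from [1], not Ramsey's code: the params dict and the projection-matrix layout are assumptions of the sketch, and the optional embedding of keys and queries is omitted.

```python
import numpy as np

def multi_head_attention(key, value, query, num_heads, head_size, params):
    """Scaled dot-product attention with num_heads heads."""
    def split(x, w):
        # project to num_heads * head_size, then split into per-head blocks:
        # (..., n, d) -> (..., num_heads, n, head_size)
        p = x @ w
        *b, n, _ = p.shape
        return p.reshape(*b, n, num_heads, head_size).swapaxes(-3, -2)

    q = split(query, params["wq"])
    k = split(key, params["wk"])
    v = split(value, params["wv"])
    logits = q @ k.swapaxes(-1, -2) / np.sqrt(head_size)
    # numerically stable softmax over the key axis
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v                                  # (..., num_heads, n_q, head_size)
    # merge heads back and project to the output dimension
    out = out.swapaxes(-3, -2).reshape(*query.shape[:-1], num_heads * head_size)
    return out @ params["wo"]

rng = np.random.default_rng(1)
d, num_heads, head_size = 4, 2, 3
k = rng.normal(size=(5, d))   # 5 key/value pairs
v = rng.normal(size=(5, d))
q = rng.normal(size=(6, d))   # 6 queries
params = {
    "wq": rng.normal(size=(d, num_heads * head_size)),
    "wk": rng.normal(size=(d, num_heads * head_size)),
    "wv": rng.normal(size=(d, num_heads * head_size)),
    "wo": rng.normal(size=(num_heads * head_size, d)),
}
out = multi_head_attention(k, v, q, num_heads, head_size, params)
print(out.shape)              # (6, 4)
```

Note the argument order matches the signature above, key, value, query, and the attended query keeps the query's leading shape while the number of key/value pairs only affects the internal attention weights.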