ramsey.nn
Neural networks and other modules for building neural processes, Bayesian neural networks, etc.
Modules
| Module | Description |
|---|---|
| MLP | A multi-layer perceptron. |
| MultiHeadAttention | Multi-head attention. |
MLP
- class ramsey.nn.MLP(input_size, output_sizes, dropout=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, activation=<jax._src.custom_derivatives.custom_jvp object>, activate_final=False, rngs=None)
A multi-layer perceptron.
- Parameters:
input_size – number of input features, i.e., the size of the last dimension of the inputs
output_sizes – number of hidden nodes per layer
dropout – dropout rate applied after each hidden layer
kernel_init – initializer for the weights of the hidden layers
bias_init – initializer for the biases of the hidden layers
use_bias – if true, the hidden layers use bias terms
activation – activation function applied after each hidden layer
activate_final – if true, also apply the activation function to the output of the last layer
rngs – a random seed generator
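The sketch below illustrates, in plain JAX rather than through ramsey, how `input_size` and `output_sizes` could determine the layer shapes. It assumes that each entry of `output_sizes` adds one dense layer and that the last entry is the output size (inferred from `activate_final` referring to a last layer). `init_mlp_params` is a hypothetical helper, and the Glorot-uniform weights and zero biases are stand-ins for `kernel_init` and `bias_init`, not a claim about ramsey's defaults.

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, input_size, output_sizes, use_bias=True):
    """Hypothetical helper: one (weight, bias) pair per entry of output_sizes.

    Layer i maps sizes[i] -> sizes[i + 1] with sizes = [input_size, *output_sizes].
    Glorot-uniform weights and zero biases are assumed stand-ins for
    kernel_init and bias_init, not necessarily ramsey's defaults.
    """
    sizes = [input_size, *output_sizes]
    kernel_init = jax.nn.initializers.glorot_uniform()
    params = []
    layer_keys = jax.random.split(key, len(output_sizes))
    for k, n_in, n_out in zip(layer_keys, sizes[:-1], sizes[1:]):
        w = kernel_init(k, (n_in, n_out))
        b = jnp.zeros((n_out,)) if use_bias else None
        params.append((w, b))
    return params


params = init_mlp_params(jax.random.PRNGKey(0), input_size=3, output_sizes=[32, 32, 2])
print([w.shape for w, _ in params])  # [(3, 32), (32, 32), (32, 2)]
```

Under this reading, `input_size` only fixes the first weight matrix; every later layer takes its input size from the previous entry of `output_sizes`.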
- __call__(inputs, is_training=False, *, rngs=None)
Transform the inputs through the MLP.
- Parameters:
inputs – input data of shape (batch_dim, spatial_dims…, feature_dim)
is_training – if true, run in training mode, i.e., apply dropout
rngs – a random seed generator
- Returns:
the transformed inputs
- Return type:
Array
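A plain-JAX sketch of what such a forward pass might look like. `mlp_forward` is hypothetical: it assumes inverted dropout applied after each hidden layer (matching the `dropout` description above), only when `is_training` is true, and it takes an explicit `jax.random` key in place of `rngs`. Because each dense layer acts on the last axis, the leading batch and spatial axes pass through unchanged.

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, inputs, activation=jax.nn.relu, activate_final=False,
                dropout=None, is_training=False, key=None):
    """Hypothetical MLP forward pass (not ramsey's implementation).

    params: list of (weight, bias) pairs, weight of shape (n_in, n_out); bias may be None.
    Dense layers act on the last axis, so leading axes are preserved.
    """
    if dropout is not None and is_training and key is None:
        raise ValueError("a PRNG key is required for dropout in training mode")
    out = inputs
    for i, (w, b) in enumerate(params):
        out = out @ w
        if b is not None:
            out = out + b
        is_last = i == len(params) - 1
        # Activation after every hidden layer, and after the last one only if requested.
        if not is_last or activate_final:
            out = activation(out)
        # Assumed: inverted dropout after each hidden layer, in training mode only.
        if not is_last and dropout is not None and is_training:
            key, subkey = jax.random.split(key)
            keep = jax.random.bernoulli(subkey, 1.0 - dropout, out.shape)
            out = jnp.where(keep, out / (1.0 - dropout), 0.0)
    return out


# Usage: 8 batch elements, 5 spatial positions, 3 features -> 2 outputs.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = [
    (jax.random.normal(k1, (3, 16)), jnp.zeros(16)),
    (jax.random.normal(k2, (16, 2)), jnp.zeros(2)),
]
x = jnp.ones((8, 5, 3))
y = mlp_forward(params, x)  # evaluation mode: dropout is skipped
y_train = mlp_forward(params, x, dropout=0.1, is_training=True, key=k3)
print(y.shape)  # (8, 5, 2)
```

Scaling the kept activations by `1 / (1 - dropout)` during training keeps their expected value equal to the evaluation-mode output, so nothing needs rescaling when `is_training` is false.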
MultiHeadAttention
- class ramsey.nn.MultiHeadAttention(in_features, num_heads, embedding=<function MultiHeadAttention.<lambda>>, *, rngs=None)
Multi-head attention.
As described in [1].
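For reference, the standard multi-head attention of Vaswani et al. (2017) is written below, with $h$ equal to `num_heads` and $d_k$ the per-head key dimension. Whether [1] and this module use exactly these projections is an assumption; per the parameter list, `embedding` would be applied to $Q$ and $K$ before these steps.

```latex
\mathrm{head}_i = \mathrm{softmax}\!\left(\frac{(Q W_i^{Q})\,(K W_i^{K})^{\top}}{\sqrt{d_k}}\right) V W_i^{V},
\qquad i = 1, \dots, h,
\\[4pt]
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
```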
- Parameters:
in_features – number of input features (int)
num_heads – number of attention heads
embedding – neural network module used to embed keys and queries before attention
rngs – an rnglib.Rngs object for random seeds
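A plain-JAX sketch of multi-head scaled dot-product attention, to illustrate how `in_features`, `num_heads` and `embedding` interact; it is not ramsey's implementation. `multi_head_attention` and its `params` dict (`wq`, `wk`, `wv`, `wo`) are hypothetical, the identity default for `embedding` is an assumption, and the sketch assumes `in_features` is divisible by `num_heads` and that the embedding preserves the feature size.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(params, queries, keys, values, num_heads,
                         embedding=lambda x: x):
    """Hypothetical multi-head attention (not ramsey's implementation).

    queries: (batch, n_q, in_features); keys, values: (batch, n_k, in_features).
    params: dict with 'wq', 'wk', 'wv', 'wo', each of shape (in_features, in_features).
    `embedding` is applied to queries and keys before attention (identity by default).
    """
    queries, keys = embedding(queries), embedding(keys)
    batch, n_q, d = queries.shape
    n_k = keys.shape[1]
    assert d % num_heads == 0, "in_features must be divisible by num_heads"
    head_dim = d // num_heads

    def split_heads(x, n):
        # (batch, n, d) -> (batch, num_heads, n, head_dim)
        return x.reshape(batch, n, num_heads, head_dim).transpose(0, 2, 1, 3)

    q = split_heads(queries @ params["wq"], n_q)
    k = split_heads(keys @ params["wk"], n_k)
    v = split_heads(values @ params["wv"], n_k)

    # Scaled dot products per head: (batch, num_heads, n_q, n_k).
    logits = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)
    weights = jax.nn.softmax(logits, axis=-1)
    heads = weights @ v  # (batch, num_heads, n_q, head_dim)
    # Concatenate the heads back into one feature axis, then project.
    concat = heads.transpose(0, 2, 1, 3).reshape(batch, n_q, d)
    return concat @ params["wo"]


in_features, num_heads = 8, 2
ks = jax.random.split(jax.random.PRNGKey(0), 6)
params = {
    name: jax.random.normal(k, (in_features, in_features)) / jnp.sqrt(in_features)
    for name, k in zip(["wq", "wk", "wv", "wo"], ks[:4])
}
q = jax.random.normal(ks[4], (4, 5, in_features))   # 5 query points
kv = jax.random.normal(ks[5], (4, 7, in_features))  # 7 key/value points
out = multi_head_attention(params, q, kv, kv, num_heads)
print(out.shape)  # (4, 5, 8)
```

Permuting the keys and values together leaves the output unchanged, since each query attends to an unordered set of key/value pairs.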