ramsey.nn

Neural networks and other modules for building neural processes, Bayesian neural networks, etc.

Modules

MLP(input_size, output_sizes[, dropout, ...])

A multi-layer perceptron.

MultiHeadAttention(in_features, num_heads[, ...])

Multi-head attention.

MLP

class ramsey.nn.MLP(input_size, output_sizes, dropout=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, use_bias=True, activation=<jax._src.custom_derivatives.custom_jvp object>, activate_final=False, rngs=None)

A multi-layer perceptron.

Parameters:
  • input_size – size of the input feature dimension

  • output_sizes – number of nodes in each layer, one entry per layer

  • dropout – dropout rate to apply after each hidden layer

  • kernel_init – initializer for weights of hidden layers

  • bias_init – initializer for bias of hidden layers

  • use_bias – if true, hidden layers use bias terms

  • activation – activation function to apply after each hidden layer

  • activate_final – if true, also apply the activation to the last layer

  • rngs – a random seed generator
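To illustrate how input_size and output_sizes determine the layer shapes, here is a minimal pure-JAX sketch. It is not ramsey's implementation: the helper name init_mlp_params is hypothetical, and the variance-scaling settings are assumptions, since the defaults in the signature above are not spelled out.

```python
import jax
import jax.numpy as jnp


def init_mlp_params(key, input_size, output_sizes, use_bias=True):
    """Create one (kernel, bias) pair per entry of output_sizes."""
    # Assumed initializer settings; the library defaults may differ.
    kernel_init = jax.nn.initializers.variance_scaling(
        1.0, "fan_in", "truncated_normal"
    )
    bias_init = jax.nn.initializers.zeros
    params = []
    in_size = input_size
    layer_keys = jax.random.split(key, len(output_sizes))
    for layer_key, out_size in zip(layer_keys, output_sizes):
        kernel = kernel_init(layer_key, (in_size, out_size), jnp.float32)
        bias = bias_init(layer_key, (out_size,), jnp.float32) if use_bias else None
        params.append((kernel, bias))
        # The output of this layer is the input of the next one.
        in_size = out_size
    return params


params = init_mlp_params(jax.random.PRNGKey(0), input_size=3, output_sizes=[16, 16, 2])
shapes = [kernel.shape for kernel, _ in params]
```

Each entry of output_sizes adds one layer, so the last entry is the output dimension of the network.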

__call__(inputs, is_training=False, *, rngs=None)

Transform the inputs through the MLP.

Return type:

Array

Parameters:
  • inputs – input data of dimension (batch_dim, spatial_dims…, feature_dim)

  • is_training – if true, runs in training mode (i.e., applies dropout)

  • rngs – a random seed generator

Returns:

returns the transformed inputs
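The following pure-JAX sketch shows what such a forward pass could look like, including the effect of is_training and activate_final. It is an illustration under assumptions, not the library's code: the function mlp_forward is hypothetical, ReLU is an assumed activation, and dropout is implemented as standard inverted dropout.

```python
import jax
import jax.numpy as jnp


def mlp_forward(params, inputs, dropout=None, activate_final=False,
                is_training=False, key=None):
    """Apply dense layers to the trailing feature axis of `inputs`."""
    outputs = inputs
    num_layers = len(params)
    for i, (kernel, bias) in enumerate(params):
        outputs = outputs @ kernel + bias
        is_last = i == num_layers - 1
        if not is_last or activate_final:
            outputs = jax.nn.relu(outputs)  # assumed activation
        # Dropout only after hidden layers, and only in training mode.
        if not is_last and is_training and dropout is not None:
            key, drop_key = jax.random.split(key)
            keep = jax.random.bernoulli(drop_key, 1.0 - dropout, outputs.shape)
            outputs = jnp.where(keep, outputs / (1.0 - dropout), 0.0)
    return outputs


key = jax.random.PRNGKey(0)
sizes = [3, 8, 2]
params = [
    (0.1 * jax.random.normal(k, (n_in, n_out)), jnp.zeros(n_out))
    for k, n_in, n_out in zip(jax.random.split(key, 2), sizes[:-1], sizes[1:])
]
inputs = jnp.ones((4, 5, 3))  # (batch_dim, spatial_dim, feature_dim)
eval_out = mlp_forward(params, inputs)
train_out = mlp_forward(params, inputs, dropout=0.5, is_training=True, key=key)
```

Only the trailing feature axis changes size; the batch and spatial axes pass through unchanged. With is_training=False no random key is needed and the output is deterministic.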

MultiHeadAttention

class ramsey.nn.MultiHeadAttention(in_features, num_heads, embedding=<function MultiHeadAttention.<lambda>>, *, rngs=None)

Multi-head attention.

As described in [1].

Parameters:
  • in_features – number of input features (int)

  • num_heads – number of heads

  • embedding – neural network module to embed keys and queries before attention

  • rngs – an rnglib.Rngs object for random seeds

__call__(key, value, query, *, rngs=None)

Apply attention to the query.

Return type:

Array

Parameters:
  • key – the key array

  • value – the value array

  • query – the query array

  • rngs – an nnx random key

Returns:

returns the attended query
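As a rough illustration of what attending to a query involves, the sketch below implements standard scaled dot-product multi-head attention in pure JAX. That the module follows exactly this scheme is an assumption. The function multi_head_attention, the four projection matrices and the (batch, length, features) input layout are all hypothetical, and the optional embedding step is omitted.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(params, key, value, query, num_heads):
    """Scaled dot-product attention with `num_heads` heads."""
    w_q, w_k, w_v, w_o = params

    def split_heads(x):
        # (batch, length, features) -> (batch, heads, length, head_dim)
        batch, length, features = x.shape
        x = x.reshape(batch, length, num_heads, features // num_heads)
        return x.transpose(0, 2, 1, 3)

    q = split_heads(query @ w_q)
    k = split_heads(key @ w_k)
    v = split_heads(value @ w_v)
    # Similarity of every query position with every key position, per head.
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(logits, axis=-1)
    attended = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
    # Merge the heads back into a single feature axis.
    batch, _, length, _ = attended.shape
    attended = attended.transpose(0, 2, 1, 3).reshape(batch, length, -1)
    return attended @ w_o


in_features, num_heads = 8, 2
rng = jax.random.PRNGKey(0)
params = [
    0.1 * jax.random.normal(k, (in_features, in_features))
    for k in jax.random.split(rng, 4)
]
key = jnp.ones((2, 6, in_features))    # 6 context positions
value = jnp.ones((2, 6, in_features))
query = jnp.ones((2, 3, in_features))  # 3 query positions
out = multi_head_attention(params, key, value, query, num_heads)
```

The output has one row per query position, so its length axis matches the query rather than the key and value arrays.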