ramsey.nn
Neural networks and other modules for building neural processes, Bayesian neural networks, etc.
Modules
ramsey.nn.MLP
Bases: Module
A multi-layer perceptron.
Attributes:
Name | Type | Description |
---|---|---|
output_sizes | Iterable[int] | number of hidden nodes per layer |
dropout | Optional[float] | dropout rate to apply after each hidden layer |
kernel_init | Initializer | initializer for the weights of the hidden layers |
bias_init | Initializer | initializer for the biases of the hidden layers |
use_bias | bool | whether the hidden layers should use bias nodes |
activation | Callable | activation function to apply after each hidden layer; defaults to relu |
activate_final | bool | if true, also activate the last layer |
activate_final: bool = False
activation: Callable = jax.nn.relu
bias_init: initializers.Initializer = initializers.zeros_init()
dropout: Optional[float] = None
kernel_init: initializers.Initializer = default_kernel_init
output_sizes: Iterable[int]
use_bias: bool = True
__call__(inputs: Array, is_training: bool = False)
Transform the inputs through the MLP.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
inputs | Array | input data of dimension (*batch_dims, spatial_dims..., feature_dims) | required |
is_training | bool | if true, uses training mode (i.e., applies dropout) | False |

Returns:
Type | Description |
---|---|
Array | the transformed inputs |
setup()
Construct all networks.
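The actual module is a Flax Module that builds its dense layers in setup(); as a rough illustration of the forward semantics only (a dense transform per entry of output_sizes, the activation after every hidden layer, and the final layer activated only when activate_final is true), here is a plain-NumPy sketch. The helper mlp_forward and its explicit weight/bias lists are hypothetical and stand in for the parameters Flax manages internally; dropout is omitted.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mlp_forward(inputs, weights, biases, activation=relu, activate_final=False):
    # Each (w, b) pair corresponds to one entry of output_sizes.
    # The activation is applied after every layer except the last,
    # unless activate_final is True.
    x = inputs
    n_layers = len(weights)
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = x @ w + b
        if i < n_layers - 1 or activate_final:
            x = activation(x)
    return x

# Illustrative shapes for output_sizes=(8, 2): 4 input features -> 8 -> 2.
rng = np.random.default_rng(0)
weights = [rng.normal(size=(4, 8)), rng.normal(size=(8, 2))]
biases = [np.zeros(8), np.zeros(2)]
out = mlp_forward(rng.normal(size=(3, 4)), weights, biases)
```

The leading batch dimension is carried through unchanged, so out has shape (3, 2) here.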
ramsey.nn.MultiHeadAttention
Bases: Attention
Multi-head attention.
As described in [1].
Attributes:
Name | Type | Description |
---|---|---|
num_heads | int | number of heads |
head_size | int | size of the heads for keys, values and queries |
embedding | Module | neural network module to embed keys and queries before attention |
References
.. [1] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017.
embedding: Optional[nn.Module]
head_size: int
num_heads: int
__call__(key: Array, value: Array, query: Array) -> Array
Apply attention to the query.
Parameters:
Name | Type | Description |
---|---|---|
key | Array | the key |
value | Array | the value |
query | Array | the query |

Returns:
Type | Description |
---|---|
Array | the attended query |
setup() -> None
Construct the networks.
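The class itself builds its projections as Flax modules; as a rough sketch of what multi-head scaled dot-product attention [1] computes, the following plain-NumPy version projects keys, values and queries into num_heads subspaces of head_size, attends within each head, then concatenates and projects back. The projection matrices wk, wv, wq, wo are hypothetical stand-ins for the learned parameters.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(key, value, query, wk, wv, wq, wo, num_heads, head_size):
    # key, value: (n_kv, d_model); query: (n_q, d_model)
    def project(x, w):
        # (n, d_model) -> (num_heads, n, head_size)
        return (x @ w).reshape(x.shape[0], num_heads, head_size).transpose(1, 0, 2)
    k, v, q = project(key, wk), project(value, wv), project(query, wq)
    # Scaled dot-product attention per head: (num_heads, n_q, n_kv).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_size)
    attn = softmax(scores, axis=-1)
    out = attn @ v                                    # (num_heads, n_q, head_size)
    # Concatenate heads and project back to the model dimension.
    out = out.transpose(1, 0, 2).reshape(query.shape[0], num_heads * head_size)
    return out @ wo                                   # (n_q, d_model)

# Hypothetical sizes: model dimension 16, 4 heads of size 8.
rng = np.random.default_rng(0)
d_model, num_heads, head_size = 16, 4, 8
wk, wv, wq = (rng.normal(size=(d_model, num_heads * head_size)) for _ in range(3))
wo = rng.normal(size=(num_heads * head_size, d_model))
attended = multi_head_attention(
    rng.normal(size=(5, d_model)),   # keys
    rng.normal(size=(5, d_model)),   # values
    rng.normal(size=(3, d_model)),   # queries
    wk, wv, wq, wo, num_heads, head_size,
)
```

The output has one row per query (shape (3, 16) here), which matches the attended-query return type documented above.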