Source code for ramsey._src.nn.MLP

import dataclasses
from collections.abc import Callable, Iterable

import jax
from flax import nnx
from flax.nnx import rnglib


@dataclasses.dataclass
class MLP(nnx.Module):
    """A multi-layer perceptron.

    Builds one ``nnx.Linear`` layer per entry of ``output_sizes``. After every
    layer except the last one (or after every layer if ``activate_final`` is
    true), optional dropout is applied, followed by the activation function.

    Args:
      input_size: number of input features of the first layer
      output_sizes: number of hidden nodes per layer
      dropout: dropout rate to apply after each hidden layer; if None, no
        dropout layer is created
      kernel_init: initializer for weights of hidden layers
      bias_init: initializer for bias of hidden layers
      use_bias: boolean if hidden layers should use bias nodes
      activation: activation function to apply after each hidden layer
      activate_final: if true, activate last layer
      rngs: a random seed generator used to initialize the linear layers
    """

    input_size: int
    output_sizes: Iterable[int]
    dropout: float | None = None
    kernel_init: nnx.initializers.Initializer = nnx.initializers.lecun_normal()
    bias_init: nnx.initializers.Initializer = nnx.initializers.zeros_init()
    use_bias: bool = True
    activation: Callable = jax.nn.relu
    activate_final: bool = False
    rngs: rnglib.Rngs | None = None

    def __post_init__(self):
        """Construct all networks."""
        # Consecutive pairs of this sequence are the (in, out) feature sizes
        # of the linear layers.
        sizes = (self.input_size, *self.output_sizes)
        self.layers = tuple(
            nnx.Linear(
                in_features=din,
                out_features=dout,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
                use_bias=self.use_bias,
                rngs=self.rngs,
            )
            for din, dout in zip(sizes[:-1], sizes[1:])
        )
        if self.dropout is not None:
            # A single dropout layer is shared by all hidden layers.
            self.dropout_layer = nnx.Dropout(self.dropout)

    def __call__(
        self,
        inputs: jax.Array,
        is_training: bool = False,
        *,
        rngs: rnglib.Rngs | None = None,
    ) -> jax.Array:
        """Transform the inputs through the MLP.

        Args:
          inputs: input data of dimension
            (batch_dim, spatial_dims..., feature_dim)
          is_training: if true, uses training mode (i.e., dropout)
          rngs: a random seed generator that is forwarded to the dropout
            layer. The dropout layer is constructed without generators of
            its own, so pass this when training with dropout enabled.

        Returns:
          returns the transformed inputs
        """
        last = len(self.layers) - 1
        out = inputs
        for i, layer in enumerate(self.layers):
            out = layer(out)
            # The final layer stays linear unless activate_final is set.
            if i < last or self.activate_final:
                if self.dropout is not None:
                    # Forward the call-time generators; previously they were
                    # accepted by the signature but never used.
                    out = self.dropout_layer(
                        out, deterministic=not is_training, rngs=rngs
                    )
                out = self.activation(out)
        return out