Source code for ramsey._src.nn.attention.multihead_attention

import dataclasses
from collections.abc import Callable

import flax
import jax
from flax import nnx
from flax.nnx import rnglib

from ramsey._src.nn.attention.attention import Attention


[docs] @dataclasses.dataclass class MultiHeadAttention(Attention): """Multi-head attention. As described in [1]. Args: in_features: int num_heads: number of heads embedding: neural network module to embed keys and queries before attention rngs: a rnglib.Rngs object for random seeds """ def __init__( self, in_features: int, num_heads: int, embedding: flax.nnx.Module | Callable = lambda x: x, *, rngs: rnglib.Rngs | None = None, ): """Construct the networks.""" super().__init__(embedding) self._attention = nnx.MultiHeadAttention( in_features=in_features, num_heads=num_heads, decode=False, rngs=rngs, )
[docs] def __call__( self, key: jax.Array, value: jax.Array, query: jax.Array, *, rngs: rnglib.Rngs | None = None, ) -> jax.Array: """Apply attention to the query. Args: key: the key :) value: the value :) query: the query :) rngs: a nnx random key Returns: returns attended query """ key, value, query = super().__call__(key, value, query) rep = self._attention(query, key, value, rngs=rngs) self._check_return_dimension(rep, value, query) return rep