"""
`continuiti.networks.multi_head_attention`

Multi-head attention in continuiti.
"""

from typing import Optional

import torch
import torch.nn as nn

from .attention import Attention
from .scaled_dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(Attention):
    r"""Multi-head attention module.

    Module as described in the paper [Attention Is All You
    Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
    with optional bias for the projections. This implementation allows the use of
    attention implementations other than the standard scaled dot product
    attention implemented by PyTorch's MultiheadAttention module.

    $$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$

    where

    $$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$

    Args:
        hidden_dim: Dimension of the hidden layers (embedding dimension).
        n_heads: Number of attention heads.
        attention: Implementation of attention (defaults to scaled dot product attention). Needs to accept the
            arguments `query`, `key`, `value`, and `attn_mask`.
        dropout_p: Dropout probability.
        bias: If True, the projections onto the different heads are performed with bias.
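
    Example:
        A minimal usage sketch; tensor shapes are illustrative only.

        ```python
        attn = MultiHeadAttention(hidden_dim=32, n_heads=4)
        query = torch.rand(2, 10, 32)        # (batch, target length, hidden_dim)
        key = value = torch.rand(2, 15, 32)  # (batch, source length, hidden_dim)
        out = attn(query, key, value)        # -> (2, 10, 32)
        ```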
    """

    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        attention: Optional[Attention] = None,
        dropout_p: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.dropout_p = dropout_p
        self.bias = bias

        if attention is None:
            attention = ScaledDotProductAttention()
        self.attention = attention

        self.head_dim = hidden_dim // n_heads
        assert (
            self.head_dim * n_heads == hidden_dim
        ), "hidden_dim must be divisible by n_heads"

        # projection networks
        self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Compute the multi-head attention output.

        Args:
            query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim).
            key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim).
            value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim).
            attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length).

        Returns:
            Attention output of shape (batch_size, target_sequence_length, hidden_dim).
        """
        assert query.ndim == key.ndim == value.ndim == 3, (
            "Query, key, and value need to have three dimensions (batch_size, ..., hidden_dim). This format ensures "
            "that the module can correctly apply the multi-head attention mechanism, which includes splitting "
            "embeddings into multiple heads, applying the internal attention implementation for each head, and "
            "concatenating and projecting the results, while ensuring that the attention mask is applied correctly."
        )
        assert (
            query.size(0) == key.size(0) == value.size(0)
        ), "Batch size does not match for input tensors"
        assert (
            query.size(-1) == key.size(-1) == value.size(-1)
        ), "Embedding/hidden dimension does not match for input tensors"

        batch_size = query.size(0)
        src_len = key.size(1)
        tgt_len = query.size(1)

        # project query, key, and value
        query = self.query_project(query)
        key = self.key_project(key)
        value = self.value_project(value)

        # form individual heads
        query = query.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
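        # query/key/value now have shape (batch_size, n_heads, sequence_length, head_dim)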

        # reshape attention mask to match heads
        if attn_mask is not None:
            assert (
                attn_mask.size(0) == batch_size
            ), "Attention mask batch size does not match input tensors."
            assert (
                attn_mask.size(1) == tgt_len
            ), "Attention mask dimension 1 needs to match the target sequence length."
            assert (
                attn_mask.size(2) == src_len
            ), "Attention mask dimension 2 needs to match the source sequence length."

            attn_mask = attn_mask.unsqueeze(1)  # mask for a single head
            attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1)  # mask for every head
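            # resulting mask shape: (batch_size, n_heads, tgt_len, src_len)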

        # perform attention
        attn_out = self.attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
        )
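        # concatenate heads: (batch_size, n_heads, tgt_len, head_dim) -> (batch_size, tgt_len, hidden_dim)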
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)

        # output projection
        return self.out_project(attn_out)
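

if __name__ == "__main__":
    # Minimal smoke test, illustrative only and not part of the library API.
    # Because of the relative imports above, run it as a module, e.g.
    # `python -m continuiti.networks.multi_head_attention`.
    # Shapes follow the `forward` docstring: query is
    # (batch_size, target_sequence_length, hidden_dim), key/value are
    # (batch_size, source_sequence_length, hidden_dim), and attn_mask is
    # (batch_size, target_sequence_length, source_sequence_length).
    batch_size, tgt_len, src_len, hidden_dim = 2, 10, 15, 32

    mha = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=4)
    query = torch.rand(batch_size, tgt_len, hidden_dim)
    key = torch.rand(batch_size, src_len, hidden_dim)
    value = torch.rand(batch_size, src_len, hidden_dim)

    # Boolean mask that lets every target position attend to every source
    # position (assuming the attention implementation accepts boolean masks,
    # as torch.nn.functional.scaled_dot_product_attention does).
    attn_mask = torch.ones(batch_size, tgt_len, src_len, dtype=torch.bool)

    out = mha(query, key, value, attn_mask=attn_mask)
    assert out.shape == (batch_size, tgt_len, hidden_dim)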