
Commit f07fe22

Authored by Samuel Burbulla
Merge pull request #108 from aai-institute/feature/multi-head-attention
Feature: Multi-Head Attention
2 parents 03f8918 + d5f7ac3 commit f07fe22

7 files changed: +449, -2 lines changed


CHANGELOG.md

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,6 +1,10 @@
 # CHANGELOG
 
-## 0.1
+## 0.2.0
+
+- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
+
+## 0.1.0
 
 - Move all content of `__init__.py` files to sub-modules.
 - Add `Trainer` class to replace `operator.fit` method.
```

src/continuiti/networks/__init__.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -6,5 +6,12 @@
 
 from .fully_connected import FullyConnected
 from .deep_residual_network import DeepResidualNetwork
+from .multi_head_attention import MultiHeadAttention
+from .scaled_dot_product_attention import ScaledDotProductAttention
 
-__all__ = ["FullyConnected", "DeepResidualNetwork"]
+__all__ = [
+    "FullyConnected",
+    "DeepResidualNetwork",
+    "MultiHeadAttention",
+    "ScaledDotProductAttention",
+]
```
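
After this change, both attention modules are re-exported from the package namespace, while the `Attention` base class remains in its submodule. A brief import sketch (assuming `continuiti` with this commit is installed):

```python
# MultiHeadAttention and ScaledDotProductAttention are now listed in __all__;
# the Attention base class is imported from its submodule.
from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention
from continuiti.networks.attention import Attention
```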

src/continuiti/networks/attention.py

Lines changed: 44 additions & 0 deletions
New file, 44 lines:

```python
"""
`continuiti.networks.attention`

Attention base class in continuiti.
"""

from abc import abstractmethod
import torch.nn as nn
import torch


class Attention(nn.Module):
    """Base class for various attention implementations.

    Attention assigns different parts of an input varying importance without set
    kernels. The importance of different components is designated using "soft"
    weights. These weights are assigned according to specific algorithms (e.g.
    scaled-dot-product attention).
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Calculates the attention scores.

        Args:
            query: query tensor; shape (batch_size, target_seq_length, hidden_dim)
            key: key tensor; shape (batch_size, source_seq_length, hidden_dim)
            value: value tensor; shape (batch_size, source_seq_length, hidden_dim)
            attn_mask: tensor indicating which values are used to calculate the output;
                shape (batch_size, target_seq_length, source_seq_length)

        Returns:
            tensor containing the outputs of the attention implementation;
            shape (batch_size, target_seq_length, hidden_dim)
        """
```

src/continuiti/networks/multi_head_attention.py

Lines changed: 138 additions & 0 deletions
New file, 138 lines:

```python
"""
`continuiti.networks.multi_head_attention`

Multi-Head-Attention in continuiti.
"""

import torch
import torch.nn as nn

from .attention import Attention
from .scaled_dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(Attention):
    r"""Multi-Head Attention module.

    Module as described in the paper [Attention is All you
    Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
    with optional bias for the projections. This implementation allows to use
    attention implementations other than the standard scaled dot product
    attention implemented by the MultiheadAttention PyTorch module.

    $$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$

    where

    $$head_i=Attention(QW_i^Q+b_i^Q, KW_i^K+b_i^K, VW_i^V+b_i^V).$$

    Args:
        hidden_dim: dimension of the hidden layers (embedding dimension).
        n_heads: number of attention heads.
        attention: implementation of attention (defaults to scaled dot product attention). Needs to have the arguments
            `query`, `key`, `value`, `attn_mask`, and `dropout_p`.
        dropout_p: dropout probability.
        bias: If True, then the projection onto the different heads is performed with bias.
    """

    def __init__(
        self,
        hidden_dim: int,
        n_heads: int,
        attention: Attention = None,
        dropout_p: float = 0,
        bias: bool = True,
    ):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_heads = n_heads
        self.dropout_p = dropout_p
        self.bias = bias

        if attention is None:
            attention = ScaledDotProductAttention()
        self.attention = attention

        self.head_dim = hidden_dim // n_heads
        assert (
            self.head_dim * n_heads == hidden_dim
        ), "hidden_dim must be divisible by n_heads"

        # projection networks
        self.query_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.key_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.value_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)
        self.out_project = nn.Linear(hidden_dim, hidden_dim, bias=bias)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        r"""Compute the attention scores.

        Args:
            query: Query tensor of shape (batch_size, target_sequence_length, hidden_dim).
            key: Key tensor of shape (batch_size, source_sequence_length, hidden_dim).
            value: Value tensor of shape (batch_size, source_sequence_length, hidden_dim).
            attn_mask: Attention mask of shape (batch_size, target_sequence_length, source_sequence_length).

        Returns:
            Attention scores of shape (batch_size, target_sequence_length, hidden_dim).
        """
        assert query.ndim == key.ndim == value.ndim == 3, (
            "Query, key, and value need to have three dimensions (batch_size, ..., hidden_dim). This format ensures that"
            "the module can correctly apply the multi-head attention mechanism, which includes splitting embeddings "
            "into multiple heads, applying the internal attention implementation for each head, concatenating and "
            "projecting results, while ensuring that the attention mask is applied correctly."
        )
        assert (
            query.size(0) == key.size(0) == value.size(0)
        ), "Batch size does not match for input tensors"
        assert (
            query.size(-1) == key.size(-1) == value.size(-1)
        ), "Embedding/hidden dimension does not match for input tensors"

        batch_size = query.size(0)
        src_len = key.size(1)
        tgt_len = query.size(1)

        # project values
        query = self.query_project(query)
        key = self.key_project(key)
        value = self.value_project(value)

        # form individual heads
        query = query.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

        # reshape attention mask to match heads
        if attn_mask is not None:
            assert (
                attn_mask.size(0) == batch_size
            ), "Attention mask batch size does not match input tensors."
            assert (
                attn_mask.size(1) == tgt_len
            ), "First dimension of the attention mask needs to match target length."
            assert (
                attn_mask.size(2) == src_len
            ), "Second dimension of the attention mask needs to match source length."

            attn_mask = attn_mask.unsqueeze(1)  # mask for a single head
            attn_mask = attn_mask.repeat(1, self.n_heads, 1, 1)  # mask for every head

        # perform attention
        attn_out = self.attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
        )
        attn_out = attn_out.transpose(1, 2).reshape(batch_size, -1, self.hidden_dim)

        # output projection
        return self.out_project(attn_out)
```
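
For orientation, a minimal usage sketch of `MultiHeadAttention` with made-up shapes (the numbers are illustrative, not part of the commit):

```python
import torch

from continuiti.networks import MultiHeadAttention

batch_size, tgt_len, src_len, hidden_dim = 2, 5, 7, 32

# Uses ScaledDotProductAttention internally by default.
mha = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=4)

query = torch.rand(batch_size, tgt_len, hidden_dim)
key = torch.rand(batch_size, src_len, hidden_dim)
value = torch.rand(batch_size, src_len, hidden_dim)

out = mha(query, key, value)
print(out.shape)  # torch.Size([2, 5, 32])
```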

src/continuiti/networks/scaled_dot_product_attention.py

Lines changed: 41 additions & 0 deletions
New file, 41 lines:

```python
"""
`continuiti.networks.scaled_dot_product_attention`

Scaled dot product attention module.
"""
import torch

from .attention import Attention
from torch.nn.functional import scaled_dot_product_attention


class ScaledDotProductAttention(Attention):
    """Scaled dot product attention module.

    This module is a wrapper for the torch implementation of the scaled dot
    product attention mechanism as described in the paper "Attention Is All You
    Need" by Vaswani et al. (2017). This attention mechanism computes the
    attention weights based on the dot product of the query and key matrices,
    scaled by the square root of the dimension of the key vectors. The weights
    are then applied to the value vectors to obtain the final output.
    """

    def __init__(self, dropout_p: float = 0.0):
        super().__init__()
        self.dropout_p = dropout_p

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        dropout_p = self.dropout_p if self.training else 0.0
        return scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
        )
```
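
Because `MultiHeadAttention` takes the attention implementation as an argument, the wrapper can also be injected explicitly; a short sketch with illustrative shapes and an all-ones boolean mask (i.e. no position is masked out):

```python
import torch

from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention

batch_size, seq_len, hidden_dim = 2, 6, 16

sdpa = ScaledDotProductAttention(dropout_p=0.1)
mha = MultiHeadAttention(hidden_dim=hidden_dim, n_heads=2, attention=sdpa)

x = torch.rand(batch_size, seq_len, hidden_dim)
mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool)

out = mha(x, x, x, attn_mask=mask)  # self-attention, shape (2, 6, 16)
```

Any other `Attention` subclass with the same `forward` signature could be passed in the same way.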
