Skip to content

Commit 47d4231

Browse files
committed
seq length in rope
1 parent f6d77c3 commit 47d4231

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

labml_nn/transformers/rope/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ def forward(self, x: torch.Tensor):
168168
# Cache $\cos$ and $\sin$ values
169169
self._build_cache(x)
170170

171+
# Sequence length
172+
seq_len = x.shape[0]
173+
171174
# Split the features, we can choose to apply rotary embeddings only to a partial set of features.
172175
x_rope, x_pass = x[..., :self.d], x[..., self.d:]
173176

@@ -185,7 +188,7 @@ def forward(self, x: torch.Tensor):
185188
# \end{align}
186189
#
187190
# for $i \in \{1, 2, \dots, \frac{d}{2}\}$
188-
x_rope = (x_rope * self.cos_cached[:x.shape[0]]) + (neg_half_x * self.sin_cached[:x.shape[0]])
191+
x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
189192

190193
#
191194
return torch.cat((x_rope, x_pass), dim=-1)

0 commit comments

Comments (0)