
Commit ec906a3

Allow passing flexible positions to positional embedding layers (#2369)
1 parent fef3bc3 commit ec906a3

File tree: 7 files changed, +137 -24 lines

keras_hub/src/layers/modeling/position_embedding.py

Lines changed: 21 additions & 6 deletions
@@ -31,6 +31,11 @@ class PositionEmbedding(keras.layers.Layer):
         start_index: An integer or integer tensor. The starting position to
             compute the position embedding from. This is useful during cached
             decoding, where each position is predicted separately in a loop.
+        positions: Tensor of shape `(sequence_length,)` or
+            `(batch_size, sequence_length)`. Custom positions for the input
+            sequence. If specified, this tensor will be used to
+            compute the position embedding, and the `start_index` argument will
+            be ignored. This is useful for cases with non-standard positions.

     Example:
@@ -91,18 +96,28 @@ def build(self, inputs_shape):
         )
         self.built = True

-    def call(self, inputs, start_index=0):
+    def call(self, inputs, start_index=0, positions=None):
         shape = ops.shape(inputs)
         feature_length = shape[-1]
         sequence_length = shape[-2]
         # trim to match the length of the input sequence, which might be less
         # than the sequence_length of the layer.
         position_embeddings = ops.convert_to_tensor(self.position_embeddings)
-        position_embeddings = ops.slice(
-            position_embeddings,
-            (start_index, 0),
-            (sequence_length, feature_length),
-        )
+        if positions is None:
+            position_embeddings = ops.slice(
+                position_embeddings,
+                (start_index, 0),
+                (sequence_length, feature_length),
+            )
+        else:
+            # Take care of unbatched `positions`.
+            if len(ops.shape(positions)) == 1:
+                positions = ops.expand_dims(positions, axis=0)
+
+            position_embeddings = ops.take(
+                position_embeddings, positions, axis=0
+            )
+
         return ops.broadcast_to(position_embeddings, shape)

     def compute_output_shape(self, input_shape):
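
For quick reference, a minimal usage sketch (not part of the diff) of the new argument; it assumes `PositionEmbedding` is importable from `keras_hub.layers`:

```python
import numpy as np

from keras_hub.layers import PositionEmbedding  # assumed public export

# Token features with shape (batch_size=2, sequence_length=4, feature_size=8).
inputs = np.random.uniform(size=(2, 4, 8)).astype("float32")
layer = PositionEmbedding(sequence_length=4)

# Default behavior: consecutive positions starting at `start_index`.
default_out = layer(inputs)

# Custom per-example positions (they may repeat or be reordered); when
# `positions` is given, `start_index` is ignored.
positions = np.array([[0, 0, 1, 2], [1, 2, 3, 0]])
custom_out = layer(inputs, positions=positions)
```

Per the docstring above, a 1D `positions` tensor of shape `(sequence_length,)` is also accepted and is broadcast across the batch.
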
keras_hub/src/layers/modeling/position_embedding_test.py

Lines changed: 21 additions & 0 deletions
@@ -141,3 +141,24 @@ def test_start_index(self):
                 sequential_output, (0, i, 0), parial_output
             )
         self.assertAllClose(full_output, sequential_output)
+
+    def test_positions(self):
+        batch_size, seq_length, feature_size = 2, 4, 5
+        data = random.uniform(shape=(batch_size, seq_length, feature_size))
+        positions = np.array([[0, 0, 1, 2], [1, 2, 3, 0]])
+
+        layer = PositionEmbedding(seq_length)
+        output = layer(data, positions=positions)
+
+        expected_output = []
+        for b_idx in range(batch_size):
+            for s_idx in range(seq_length):
+                actual_position = positions[b_idx, s_idx]
+                expected_output.append(
+                    layer.position_embeddings.numpy()[actual_position]
+                )
+
+        expected_output = np.reshape(
+            np.array(expected_output), (batch_size, seq_length, feature_size)
+        )
+        self.assertAllClose(output, expected_output, rtol=1e-5, atol=1e-5)
keras_hub/src/layers/modeling/rotary_embedding.py

Lines changed: 16 additions & 6 deletions
@@ -37,6 +37,11 @@ class RotaryEmbedding(keras.layers.Layer):
         start_index: An integer or integer tensor. The starting position to
             compute the rotary embedding from. This is useful during cached
             decoding, where each position is predicted separately in a loop.
+        positions: Tensor of shape `(sequence_length,)` or
+            `(batch_size, sequence_length)`. Custom positions for the input
+            sequence. If specified, this tensor will be used to
+            compute the rotary embedding, and the `start_index` argument will
+            be ignored. This is useful for cases with non-standard positions.

     Examples:
@@ -76,6 +81,11 @@ def __init__(
         self.built = True

     def call(self, inputs, start_index=0, positions=None):
+        # Take care of unbatched `positions`.
+        if positions is not None:
+            if len(ops.shape(positions)) == 1:
+                positions = ops.expand_dims(positions, axis=0)
+
         inputs = ops.moveaxis(
             inputs, (self.feature_axis, self.sequence_axis), (-1, 1)
         )
@@ -103,6 +113,7 @@ def _compute_positions(self, inputs, start_index=0):
         return positions + ops.cast(start_index, dtype="float32")

     def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
+        batch_axis = 0
         feature_axis = len(inputs.shape) - 1
         sequence_axis = 1

@@ -111,21 +122,20 @@ def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):

         if positions is None:
             positions = self._compute_positions(inputs, start_index)
+            positions = ops.expand_dims(positions, axis=batch_axis)
         else:
             positions = ops.cast(positions, "float32")
-
         positions = positions / ops.cast(self.scaling_factor, "float32")
-        freq = ops.einsum("i,j->ij", positions, inverse_freq)
+
+        freq = ops.einsum("bi,j->bij", positions, inverse_freq)
+
         embedding = ops.stack((freq, freq), axis=-2)
         embedding = ops.reshape(
             embedding, (*ops.shape(freq)[:-1], ops.shape(freq)[-1] * 2)
         )

-        # Reshape the embedding to be broadcastable with input shape.
-        if feature_axis < sequence_axis:
-            embedding = ops.transpose(embedding)
         for axis in range(len(inputs.shape)):
-            if axis != sequence_axis and axis != feature_axis:
+            if axis not in (batch_axis, sequence_axis, feature_axis):
                 embedding = ops.expand_dims(embedding, axis)

         cos_emb = ops.cast(ops.cos(embedding), self.compute_dtype)
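
For illustration only (not part of the commit), a minimal sketch of passing batched positions to the rotary layer; it assumes the layer is exposed as `keras_hub.layers.RotaryEmbedding`:

```python
import numpy as np

from keras_hub.layers import RotaryEmbedding  # assumed public export

# Query/key tensor with shape
# (batch_size=2, sequence_length=2, num_heads=1, head_dim=16).
x = np.random.standard_normal(size=(2, 2, 1, 16)).astype("float32")
layer = RotaryEmbedding()

# A distinct position per batch element and timestep; a 1D tensor of shape
# (sequence_length,) is also accepted and broadcast over the batch.
positions = np.array([[0.0, 0.0], [0.0, 1.0]], dtype="float32")
rotated = layer(x, positions=positions)
```
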
keras_hub/src/layers/modeling/rotary_embedding_test.py

Lines changed: 43 additions & 3 deletions
@@ -107,7 +107,7 @@ def test_float16_dtype(self):
         # output dtype for this layer should be float16.
         self.assertEqual(outputs.dtype, "float16")

-    def test_positions_array(self):
+    def test_positions_1d_array(self):
         rng = np.random.default_rng(0)
         x = rng.standard_normal(size=(1, 2, 1, 16)).astype(np.float32)
         positions = ops.cast([0, 0], "float32")
@@ -152,9 +152,49 @@ def test_positions_array(self):
         # fmt: on

         layer = RotaryEmbedding()
-        got = layer(x, positions=positions)
+        output = layer(x, positions=positions)

-        np.testing.assert_allclose(expected, ops.convert_to_numpy(got))
+        np.testing.assert_allclose(expected, ops.convert_to_numpy(output))
+
+    def test_positions_2d_array(self):
+        rng = np.random.default_rng(0)
+        x = rng.standard_normal(size=(2, 2, 1, 16)).astype(np.float32)
+        positions = ops.cast([[0, 0], [0, 1]], "float32")
+
+        # fmt: off
+        expected = np.array(
+            [
+                [
+                    [[0.12573022, -0.13210486, 0.64042264, 0.10490011,
+                      -0.5356694, 0.36159506, 1.304, 0.94708097,
+                      -0.70373523, -1.2654215, -0.62327445, 0.04132598,
+                      -2.3250308, -0.21879166, -1.245911, -0.7322674]],
+                    [[-0.544259, -0.31630015, 0.41163054, 1.0425134,
+                      -0.12853466, 1.3664634, -0.6651947, 0.35151008,
+                      0.90347016, 0.0940123, -0.7434993, -0.9217254,
+                      -0.45772582, 0.22019513, -1.0096182, -0.20917557]
+                    ]
+                ],
+                [
+                    [[-0.159225017, 0.540845573, 0.214659125, 0.355372697,
+                      -0.653828621, -0.129613638, 0.783975482, 1.49343109,
+                      -1.25906551, 1.51392376, 1.34587538, 0.781311393,
+                      0.264455616, -0.313922822, 1.45802069, 1.96025836]],
+                    [[0.611709595, 1.03343689, 0.47380957, -1.18679309,
+                      -8.96309502e-05, 0.660170913, -1.29010022, 0.395278841,
+                      1.74827969, 1.07050526, -1.14252377, -0.699575782,
+                      -0.436457992, -1.1677202, 1.73807859, -0.495785743]
+                    ]
+                ]
+            ],
+            dtype=np.float32
+        )  # noqa
+        # fmt: on
+
+        layer = RotaryEmbedding()
+        output = layer(x, positions=positions)
+
+        self.assertAllClose(output, expected, rtol=1e-5, atol=1e-5)

     def test_rope_scaling(self):
         # Reference values computed from Huggingface llama implementation
keras_hub/src/layers/modeling/sine_position_encoding.py

Lines changed: 21 additions & 8 deletions
@@ -30,6 +30,11 @@ class SinePositionEncoding(keras.layers.Layer):
         start_index: An integer or integer tensor. The starting position to
             compute the encoding from. This is useful during cached decoding,
             where each position is predicted separately in a loop.
+        positions: Tensor of shape `(sequence_length,)` or
+            `(batch_size, sequence_length)`. Custom positions for the input
+            sequence. If specified, this tensor will be used to
+            compute the position embedding, and the `start_index` argument will
+            be ignored. This is useful for cases with non-standard positions.

     Example:
@@ -58,27 +63,35 @@ def __init__(
         self.max_wavelength = max_wavelength
         self.built = True

-    def call(self, inputs, start_index=0):
+    def call(self, inputs, start_index=0, positions=None):
         shape = ops.shape(inputs)
         seq_length = shape[-2]
         hidden_size = shape[-1]
-        positions = ops.arange(seq_length)
-        positions = ops.cast(positions + start_index, self.compute_dtype)
+
+        if positions is None:
+            positions = ops.arange(seq_length)
+            positions = ops.cast(positions + start_index, self.compute_dtype)
+
+        # Take care of unbatched `positions`.
+        if len(ops.shape(positions)) == 1:
+            positions = ops.expand_dims(positions, axis=0)
+
         min_freq = ops.cast(1 / self.max_wavelength, dtype=self.compute_dtype)
         timescales = ops.power(
             min_freq,
             ops.cast(2 * (ops.arange(hidden_size) // 2), self.compute_dtype)
             / ops.cast(hidden_size, self.compute_dtype),
         )
-        angles = ops.expand_dims(positions, 1) * ops.expand_dims(timescales, 0)
+        angles = ops.einsum("bi,j->bij", positions, timescales)
+
         # even indices are sine, odd are cosine
         cos_mask = ops.cast(ops.arange(hidden_size) % 2, self.compute_dtype)
         sin_mask = 1 - cos_mask
-        # embedding shape is [seq_length, hidden_size]
-        positional_encodings = (
-            ops.sin(angles) * sin_mask + ops.cos(angles) * cos_mask
-        )

+        # embedding shape is `[bsz (or 1), seq_length, hidden_size]`.
+        positional_encodings = ops.einsum(
+            "bij,j->bij", ops.sin(angles), sin_mask
+        ) + ops.einsum("bij,j->bij", ops.cos(angles), cos_mask)
         return ops.broadcast_to(positional_encodings, shape)

     def get_config(self):
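
For illustration only (not part of the commit), a minimal sketch of passing custom positions to the sinusoidal encoding layer; it assumes the layer is exposed as `keras_hub.layers.SinePositionEncoding` and mirrors the shapes used in the new test:

```python
import numpy as np

from keras_hub.layers import SinePositionEncoding  # assumed public export

# Inputs with shape (batch_size=2, sequence_length=2, feature_size=4).
inputs = np.random.uniform(size=(2, 2, 4)).astype("float32")
layer = SinePositionEncoding()

# Row 0 is encoded at positions [0, 1]; row 1 at [1, 0]. When `positions`
# is given, `start_index` is ignored.
positions = np.array([[0, 1], [1, 0]], dtype="float32")
encodings = layer(inputs, positions=positions)
```
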
keras_hub/src/layers/modeling/sine_position_encoding_test.py

Lines changed: 13 additions & 0 deletions
@@ -94,3 +94,16 @@ def test_start_index(self):
                 sequential_output, (0, i, 0), parial_output
             )
         self.assertAllClose(full_output, sequential_output)
+
+    def test_positions(self):
+        batch_size, seq_length, feature_size = 2, 2, 4
+        data = random.uniform(shape=(batch_size, seq_length, feature_size))
+        positions = ops.array([[0, 1], [1, 0]])
+
+        layer = SinePositionEncoding()
+        output = layer(data, positions=positions)
+
+        pos_0 = [0.0, 1.0, 0.0, 1.0]
+        pos_1 = [0.84147, 0.54030, 0.009999, 0.99995]
+        expected = [[pos_0, pos_1], [pos_1, pos_0]]
+        self.assertAllClose(expected, output, rtol=1e-5, atol=1e-5)
keras_hub/src/layers/modeling/token_and_position_embedding.py

Lines changed: 2 additions & 1 deletion
@@ -120,11 +120,12 @@ def get_config(self):
         )
         return config

-    def call(self, inputs, start_index=0):
+    def call(self, inputs, start_index=0, positions=None):
         embedded_tokens = self.token_embedding(inputs)
         embedded_positions = self.position_embedding(
             embedded_tokens,
             start_index=start_index,
+            positions=positions,
         )
         outputs = embedded_tokens + embedded_positions
         return outputs
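
For illustration only (not part of the commit), a minimal sketch showing that `positions` is simply forwarded to the wrapped `PositionEmbedding`; it assumes the layer is exposed as `keras_hub.layers.TokenAndPositionEmbedding` with its usual `vocabulary_size`, `sequence_length`, and `embedding_dim` arguments:

```python
import numpy as np

from keras_hub.layers import TokenAndPositionEmbedding  # assumed public export

# Token ids with shape (batch_size=2, sequence_length=4).
token_ids = np.array([[3, 8, 8, 1], [5, 2, 0, 0]])

layer = TokenAndPositionEmbedding(
    vocabulary_size=10,
    sequence_length=4,
    embedding_dim=16,
)

# Custom positions are passed straight through to the position embedding,
# overriding `start_index` for each sequence.
positions = np.array([[0, 1, 2, 3], [0, 1, 0, 0]])
outputs = layer(token_ids, positions=positions)
```
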