@@ -107,7 +107,7 @@ def test_float16_dtype(self):
107
107
# output dtype for this layer should be float16.
108
108
self .assertEqual (outputs .dtype , "float16" )
109
109
110
- def test_positions_array (self ):
110
+ def test_positions_1d_array (self ):
111
111
rng = np .random .default_rng (0 )
112
112
x = rng .standard_normal (size = (1 , 2 , 1 , 16 )).astype (np .float32 )
113
113
positions = ops .cast ([0 , 0 ], "float32" )
@@ -152,9 +152,49 @@ def test_positions_array(self):
152
152
# fmt: on
153
153
154
154
layer = RotaryEmbedding ()
155
- got = layer (x , positions = positions )
155
+ output = layer (x , positions = positions )
156
156
157
- np .testing .assert_allclose (expected , ops .convert_to_numpy (got ))
157
+ np .testing .assert_allclose (expected , ops .convert_to_numpy (output ))
158
+
159
def test_positions_2d_array(self):
    """Check `RotaryEmbedding` output when `positions` is a 2D array.

    A seeded random input of shape (2, 2, 1, 16) is passed through a
    default-constructed layer together with a (2, 2) float32 `positions`
    array, and the result is compared against hard-coded reference
    values within a 1e-5 relative/absolute tolerance.
    """
    # fmt: off
    # Reference rows for the first entry along the leading axis
    # (positions [0, 0]).
    reference_first = [
        [[0.12573022, -0.13210486, 0.64042264, 0.10490011,
          -0.5356694, 0.36159506, 1.304, 0.94708097,
          -0.70373523, -1.2654215, -0.62327445, 0.04132598,
          -2.3250308, -0.21879166, -1.245911, -0.7322674]],
        [[-0.544259, -0.31630015, 0.41163054, 1.0425134,
          -0.12853466, 1.3664634, -0.6651947, 0.35151008,
          0.90347016, 0.0940123, -0.7434993, -0.9217254,
          -0.45772582, 0.22019513, -1.0096182, -0.20917557]],
    ]
    # Reference rows for the second entry along the leading axis
    # (positions [0, 1]).
    reference_second = [
        [[-0.159225017, 0.540845573, 0.214659125, 0.355372697,
          -0.653828621, -0.129613638, 0.783975482, 1.49343109,
          -1.25906551, 1.51392376, 1.34587538, 0.781311393,
          0.264455616, -0.313922822, 1.45802069, 1.96025836]],
        [[0.611709595, 1.03343689, 0.47380957, -1.18679309,
          -8.96309502e-05, 0.660170913, -1.29010022, 0.395278841,
          1.74827969, 1.07050526, -1.14252377, -0.699575782,
          -0.436457992, -1.1677202, 1.73807859, -0.495785743]],
    ]
    # fmt: on
    expected = np.array(
        [reference_first, reference_second], dtype=np.float32
    )

    # Fixed seed so the input matches the hard-coded reference values.
    generator = np.random.default_rng(0)
    inputs = generator.standard_normal(size=(2, 2, 1, 16))
    inputs = inputs.astype(np.float32)
    position_ids = ops.cast([[0, 0], [0, 1]], "float32")

    rotary = RotaryEmbedding()
    output = rotary(inputs, positions=position_ids)

    self.assertAllClose(output, expected, rtol=1e-5, atol=1e-5)
158
198
159
199
def test_rope_scaling (self ):
160
200
# Reference values computed from Huggingface llama implementation
0 commit comments