Skip to content

Commit b3ef961

Browse files
committed
Qwen vision encoder tests
1 parent c5f29b8 commit b3ef961

File tree

2 files changed

+551
-7
lines changed

2 files changed

+551
-7
lines changed

tests/multimodal_test_utils.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55
import torch
66

7-
87
def create_random_jax_torch(*shape, dtype=np.float32):
98
"""Create random array and return both JAX and PyTorch versions.
109
@@ -19,6 +18,26 @@ def create_random_jax_torch(*shape, dtype=np.float32):
1918
return jnp.array(np_array), torch.from_numpy(np_array)
2019

2120

21+
def split_into_patches(x, temporal_patch_size, patch_size):
  """Split a 5D tensor into patches for PyTorch vision encoder input.

  Converts from full image format (batch, channels, temporal, height, width) to
  patch format (num_patches, channels, temporal_patch_size, patch_size, patch_size).

  Patches are emitted in row-major order over
  (batch, temporal block, height block, width block).

  Args:
    x: Tensor of shape (batch, channels, temporal, height, width). Must support
      `.reshape` and `.permute` (e.g. a `torch.Tensor`).
    temporal_patch_size: Number of frames per patch; must divide the temporal dim.
    patch_size: Spatial edge length of a patch; must divide height and width.

  Returns:
    Tensor of shape (num_patches, channels, temporal_patch_size, patch_size, patch_size)
    where num_patches =
      batch * (temporal//temporal_patch_size) * (height//patch_size) * (width//patch_size)

  Raises:
    AssertionError: If temporal, height or width is not divisible by its patch size.
  """
  B, C, T, H, W = x.shape
  assert T % temporal_patch_size == 0, f"Temporal dimension {T} must be divisible by {temporal_patch_size}"
  assert H % patch_size == 0, f"Height {H} must be divisible by {patch_size}"
  assert W % patch_size == 0, f"Width {W} must be divisible by {patch_size}"

  t_blocks = T // temporal_patch_size
  h_blocks = H // patch_size
  w_blocks = W // patch_size

  # Split every patched axis into (number of blocks, block size). The temporal
  # axis must be split as well; otherwise inputs with T > temporal_patch_size
  # would yield patches that still span the full temporal extent.
  x = x.reshape(B, C, t_blocks, temporal_patch_size, h_blocks, patch_size, w_blocks, patch_size)
  # Move block indices to the front:
  # (B, t_blocks, h_blocks, w_blocks, C, temporal_patch_size, patch_size, patch_size)
  x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)
  # Collapse batch and block indices into a single patch axis.
  return x.reshape(-1, C, temporal_patch_size, patch_size, patch_size)
39+
40+
2241
def assert_all_close_jax_torch(jax_tensor, torch_tensor, rtol, atol, error_msg=""):
2342
"""Compare JAX and PyTorch tensors for numerical closeness.
2443
@@ -295,8 +314,6 @@ def copy_maxtext_encoder_weights(torch_encoder, maxtext_encoder):
295314

296315

297316
# Vision-specific weight copying utilities
298-
299-
300317
def copy_conv3d_weights(torch_conv, jax_conv):
301318
"""Copy weights from PyTorch Conv3d to JAX nnx.Conv (3D)."""
302319
# PyTorch Conv3d: (out_channels, in_channels, kD, kH, kW)
@@ -332,8 +349,6 @@ def copy_vision_encoder_weights(torch_encoder, jax_encoder):
332349
torch_encoder: PyTorch Qwen3OmniMoeVisionEncoder
333350
jax_encoder: JAX Qwen3OmniMoeVisionEncoder
334351
"""
335-
import jax.numpy as jnp
336-
337352
# Copy patch embedding
338353
copy_patch_embed_weights(torch_encoder.patch_embed, jax_encoder.patch_embed)
339354

@@ -362,8 +377,6 @@ def copy_vision_encoder_weights(torch_encoder, jax_encoder):
362377

363378

364379
# Audio-specific utilities
365-
366-
367380
def create_block_diagonal_attention_mask(
368381
cu_seqlens, dtype
369382
):

0 commit comments

Comments
 (0)