44import numpy as np
55import torch
66
7-
87def create_random_jax_torch (* shape , dtype = np .float32 ):
98 """Create random array and return both JAX and PyTorch versions.
109
@@ -19,6 +18,26 @@ def create_random_jax_torch(*shape, dtype=np.float32):
1918 return jnp .array (np_array ), torch .from_numpy (np_array )
2019
2120
def split_into_patches(x, temporal_patch_size, patch_size):
  """Split a 5D tensor into patches for PyTorch vision encoder input.

  Converts from full image format (batch, channels, temporal, height, width) to
  patch format (num_patches, channels, temporal_patch_size, patch_size, patch_size).
  Patches are ordered batch-major, then by temporal block, then by row, then by
  column of the patch grid.

  Args:
    x: Tensor of shape (batch, channels, temporal, height, width). Must support
      `.reshape` and `.permute` (e.g. a torch.Tensor).
    temporal_patch_size: Number of frames in each patch; must divide `temporal`.
    patch_size: Spatial side length of each patch; must divide `height` and `width`.

  Returns:
    Tensor of shape (num_patches, channels, temporal_patch_size, patch_size, patch_size)
    where num_patches =
      batch * (temporal//temporal_patch_size) * (height//patch_size) * (width//patch_size)

  Raises:
    AssertionError: If temporal, height or width is not divisible by its patch size.
    ValueError: If `x` is not 5-dimensional (raised by the shape unpacking).
  """
  B, C, T, H, W = x.shape
  assert T % temporal_patch_size == 0, f"Temporal dimension {T} must be divisible by {temporal_patch_size} "
  assert H % patch_size == 0, f"Height {H} must be divisible by {patch_size} "
  assert W % patch_size == 0, f"Width {W} must be divisible by {patch_size} "

  grid_t = T // temporal_patch_size
  grid_h = H // patch_size
  grid_w = W // patch_size

  # Split every patched axis into (grid, within-patch) pairs. The temporal axis
  # must be split too; otherwise inputs with T > temporal_patch_size would come
  # back with the full T in each "patch", contradicting the documented shape.
  x = x.reshape(B, C, grid_t, temporal_patch_size, grid_h, patch_size, grid_w, patch_size)
  # -> (B, grid_t, grid_h, grid_w, C, temporal_patch_size, patch_size, patch_size)
  x = x.permute(0, 2, 4, 6, 1, 3, 5, 7)
  # Collapse batch and grid axes into a single leading num_patches axis.
  return x.reshape(-1, C, temporal_patch_size, patch_size, patch_size)
39+
40+
2241def assert_all_close_jax_torch (jax_tensor , torch_tensor , rtol , atol , error_msg = "" ):
2342 """Compare JAX and PyTorch tensors for numerical closeness.
2443
@@ -295,8 +314,6 @@ def copy_maxtext_encoder_weights(torch_encoder, maxtext_encoder):
295314
296315
297316# Vision-specific weight copying utilities
298-
299-
300317def copy_conv3d_weights (torch_conv , jax_conv ):
301318 """Copy weights from PyTorch Conv3d to JAX nnx.Conv (3D)."""
302319 # PyTorch Conv3d: (out_channels, in_channels, kD, kH, kW)
@@ -332,8 +349,6 @@ def copy_vision_encoder_weights(torch_encoder, jax_encoder):
332349 torch_encoder: PyTorch Qwen3OmniMoeVisionEncoder
333350 jax_encoder: JAX Qwen3OmniMoeVisionEncoder
334351 """
335- import jax .numpy as jnp
336-
337352 # Copy patch embedding
338353 copy_patch_embed_weights (torch_encoder .patch_embed , jax_encoder .patch_embed )
339354
@@ -362,8 +377,6 @@ def copy_vision_encoder_weights(torch_encoder, jax_encoder):
362377
363378
364379# Audio-specific utilities
365-
366-
367380def create_block_diagonal_attention_mask (
368381 cu_seqlens , dtype
369382):
0 commit comments