Commit 65c1f4e

fix: use float32 in pos embed for MPS compatibility
1 parent 350abfc commit 65c1f4e

File tree

1 file changed (+1, −1 lines)


vggt/heads/utils.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 10
     """
     assert embed_dim % 2 == 0
     device = pos.device
-    omega = torch.arange(embed_dim // 2, dtype=torch.float if device.type == "mps" else torch.double, device=device)
+    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
     omega /= embed_dim / 2.0
     omega = 1.0 / omega_0**omega  # (D/2,)
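For context: `torch.float` is an alias of `torch.float32`, so this change makes the dtype explicit rather than altering behavior; the substantive point of the branch is that Apple's MPS backend does not support float64 tensors, so the code falls back from `torch.double` on that device. The frequency vector `omega` that this line produces follows the standard sinusoidal-embedding recipe. A minimal, torch-free sketch of that computation (the function name and structure are inferred from the diff, not copied from the repo):

```python
import math

def sincos_frequencies(embed_dim: int, omega_0: float = 10000.0):
    """Sketch of the omega frequencies in make_sincos_pos_embed.

    Mirrors the diff line by line:
        omega = arange(embed_dim // 2) / (embed_dim / 2)
        omega = 1 / omega_0 ** omega          # shape (D/2,)
    In the real code this runs as float32 on MPS (no float64 there)
    and float64 elsewhere.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    return [1.0 / omega_0 ** (i / half) for i in range(half)]

freqs = sincos_frequencies(8)
print(len(freqs))   # 4
print(freqs[0])     # 1.0 (the lowest index gives the slowest-decaying frequency)
```

The list is monotonically decreasing from 1.0 toward `1 / omega_0`, giving the geometric spread of wavelengths used by sin/cos positional embeddings.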
