Commit 51ee022

Merge pull request #112 from Vincentqyw/main
Fix MPS compatibility by using float32 in positional embeddings
2 parents: 588a0a2 + 65c1f4e

File tree

1 file changed: +2 −1 lines

vggt/heads/utils.py

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 10
         emb: The generated 1D positional embedding.
     """
     assert embed_dim % 2 == 0
-    omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
+    device = pos.device
+    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
     omega /= embed_dim / 2.0
     omega = 1.0 / omega_0**omega  # (D/2,)
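The point of the patch is that Apple's MPS backend does not support float64, so `torch.arange(..., dtype=torch.double)` raises on an `mps` device. A minimal self-contained sketch of the patched function is below; only the lines visible in the diff come from the commit, while the `omega_0` default and the sin/cos tail are assumptions following the standard ViT-style sincos recipe:

```python
import torch

def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """Generate a 1D sin/cos positional embedding of shape (M, embed_dim).

    Sketch based on the diff above; the reshape/sin/cos tail and the
    omega_0 default are assumed, not taken from the commit.
    """
    assert embed_dim % 2 == 0
    device = pos.device
    # MPS has no float64 kernels, so fall back to float32 there (the fix in this commit)
    omega = torch.arange(embed_dim // 2, dtype=torch.float32 if device.type == "mps" else torch.double, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega                              # (D/2,)
    out = pos.reshape(-1).to(omega.dtype)[:, None] * omega    # (M, D/2) outer product
    emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1)  # (M, D)
    return emb.float()
```

On CUDA or CPU the frequencies are still computed in float64 for precision and only cast down at the end, so the change is behavior-preserving everywhere except MPS, where float32 is the widest float available.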
