46 changes: 28 additions & 18 deletions step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py
@@ -325,6 +325,34 @@ def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):

bs, N, D = pc.shape

# Apply multi-resolution sampling FIRST (if enabled), before embedding,
# so that pc, feats, sharp_pc, and sharp_feat are all downsampled together
if self.use_multi_reso:
    from torch_cluster import fps

    # choice(..., size=1, p=...) is numpy.random's signature, so `random`
    # here is assumed to be NumPy's RNG rather than the stdlib module
    resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[0]

    if resolution != N:
        flattened = pc.view(bs * N, D)  # (bs, N, D) -> (bs*N, D), e.g. (103, 4096, 3) -> (421888, 3)
        batch = torch.arange(bs).to(pc.device)  # (bs,)
        batch = torch.repeat_interleave(batch, N)  # (bs*N,), cloud id per point, e.g. 421888
        pos = flattened.to(torch.float16)  # half precision to cut fps memory
        ratio = 1.0 * resolution / N  # keep fraction, e.g. 256/4096 = 0.0625
        idx = fps(pos, batch, ratio=ratio)  # flat indices into (bs*N, D), e.g. 26368 points kept

        # Downsample all inputs with the same indices so they stay point-aligned
        pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)

        if feats is not None:
            feats = feats.view(bs * N, -1)[idx].view(bs, -1, feats.shape[-1])
        if sharp_pc is not None:
            sharp_pc = sharp_pc.view(bs * N, -1)[idx].view(bs, -1, D)
        if sharp_feat is not None:
            sharp_feat = sharp_feat.view(bs * N, -1)[idx].view(bs, -1, sharp_feat.shape[-1])

        bs, N, D = pc.shape

# Continue with the original number of points if no downsampling is
# applied, or with the reduced point count after multi-resolution sampling
data = self.embedder(pc)
if feats is not None:
if self.embed_point_feats:
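Note on the resolution draw above: the signature choice(..., size=1, p=...) matches numpy.random rather than Python's stdlib random module, so the surrounding file is assumed to use NumPy's RNG under the name random. A minimal, self-contained sketch of the same draw; the candidate resolutions and probabilities below are illustrative, not taken from the PR:

import numpy as np

resolutions = [256, 1024, 4096]    # hypothetical candidate point counts
sampling_prob = [0.25, 0.25, 0.5]  # hypothetical probabilities; must sum to 1

# size=1 returns a length-1 array; [0] unwraps it to a scalar, as in the diff
resolution = np.random.choice(resolutions, size=1, p=sampling_prob)[0]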
Expand All @@ -340,24 +368,6 @@ def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):
sharp_data = torch.cat([sharp_data, sharp_feat], dim=-1)
sharp_data = self.input_proj_sharp(sharp_data)

if self.use_multi_reso:
    resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[0]

    if resolution != N:
        flattened = pc.view(bs * N, D)  # bs*N, 64. 103,4096,3 -> 421888,3
        batch = torch.arange(bs).to(pc.device)  # 103
        batch = torch.repeat_interleave(batch, N)  # bs*N. 421888
        pos = flattened.to(torch.float16)
        ratio = 1.0 * resolution / N  # 0.0625
        idx = fps(pos, batch, ratio=ratio)  # 26368
        pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)
        bs, N, D = feats.shape
        flattened1 = feats.view(bs * N, D)
        feats = flattened1.view(bs * N, -1)[idx].view(bs, -1, D)
        bs, N, D = pc.shape

if self.use_downsample:
###### fps
from torch_cluster import fps
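For reference, a standalone sketch of the batched farthest-point-sampling pattern the new block relies on; tensor sizes are illustrative and this is an assumption-laden demo, not code from the PR. The key property is that torch_cluster.fps returns indices into the flattened (bs*N, D) tensor, so indexing every per-point tensor with the same idx keeps points and features aligned:

import torch
from torch_cluster import fps

bs, N, D = 4, 4096, 3          # illustrative batch size, points per cloud, coord dim
resolution = 1024              # illustrative target points per cloud

pc = torch.rand(bs, N, D)      # batched point cloud
feats = torch.rand(bs, N, 16)  # per-point features

flattened = pc.view(bs * N, D)                        # (bs*N, D)
batch = torch.repeat_interleave(torch.arange(bs), N)  # cloud id for each point

# fps keeps ~ratio of each cloud's points and returns flat indices
idx = fps(flattened, batch, ratio=resolution / N)

# Index both tensors with the same idx so they stay point-aligned
pc_ds = flattened[idx].view(bs, -1, D)
feats_ds = feats.view(bs * N, -1)[idx].view(bs, -1, feats.shape[-1])

print(pc_ds.shape, feats_ds.shape)  # torch.Size([4, 1024, 3]) torch.Size([4, 1024, 16])

This is also why the replacement block above reads the feature width from feats.shape[-1] instead of overwriting bs, N, D from feats.shape as the removed block did: the feature dimension generally differs from D, so reusing it in the reshape would misalign the tensors.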