From 292887e5cbba213222d9a9bf30fea02316a03087 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mehmet=20Ali=20=C3=96zer?=
Date: Sun, 15 Jun 2025 17:09:33 +0300
Subject: [PATCH] Fix: Downsample before embedding in multi-resolution mode
 and fix fps import location

---
 .../autoencoders/michelangelo_autoencoder.py | 46 +++++++++++--------
 1 file changed, 28 insertions(+), 18 deletions(-)

diff --git a/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py b/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py
index 130cc84..6f1a19f 100755
--- a/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py
+++ b/step1x3d_geometry/models/autoencoders/michelangelo_autoencoder.py
@@ -325,6 +325,34 @@ def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):
 
         bs, N, D = pc.shape
 
+        # Apply multi-resolution sampling FIRST (if enabled)
+        if self.use_multi_reso:
+            from torch_cluster import fps
+
+            resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[0]
+
+            if resolution != N:
+                flattened = pc.view(bs * N, D)  # bs*N, 64. 103,4096,3 -> 421888,3
+                batch = torch.arange(bs).to(pc.device)  # 103
+                batch = torch.repeat_interleave(batch, N)  # bs*N. 421888
+                pos = flattened.to(torch.float16)
+                ratio = 1.0 * resolution / N  # 0.0625
+                idx = fps(pos, batch, ratio=ratio)  # 26368
+
+                # Downsample all inputs
+                pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)
+
+                if feats is not None:
+                    feats = feats.view(bs * N, -1)[idx].view(bs, -1, feats.shape[-1])
+                if sharp_pc is not None:
+                    sharp_pc = sharp_pc.view(bs * N, -1)[idx].view(bs, -1, D)
+                if sharp_feat is not None:
+                    sharp_feat = sharp_feat.view(bs * N, -1)[idx].view(bs, -1, sharp_feat.shape[-1])
+
+                bs, N, D = pc.shape
+
+        # Continue with the original number of points
+        # if no downsampling is applied or after multi-resolution sampling
         data = self.embedder(pc)
         if feats is not None:
             if self.embed_point_feats:
@@ -340,24 +368,6 @@ def _forward(self, pc, feats, sharp_pc=None, sharp_feat=None):
                 sharp_data = torch.cat([sharp_data, sharp_feat], dim=-1)
             sharp_data = self.input_proj_sharp(sharp_data)
 
-        if self.use_multi_reso:
-            resolution = random.choice(self.resolutions, size=1, p=self.sampling_prob)[
-                0
-            ]
-
-            if resolution != N:
-                flattened = pc.view(bs * N, D)  # bs*N, 64. 103,4096,3 -> 421888,3
-                batch = torch.arange(bs).to(pc.device)  # 103
-                batch = torch.repeat_interleave(batch, N)  # bs*N. 421888
-                pos = flattened.to(torch.float16)
-                ratio = 1.0 * resolution / N  # 0.0625
-                idx = fps(pos, batch, ratio=ratio)  # 26368
-                pc = pc.view(bs * N, -1)[idx].view(bs, -1, D)
-                bs, N, D = feats.shape
-                flattened1 = feats.view(bs * N, D)
-                feats = flattened1.view(bs * N, -1)[idx].view(bs, -1, D)
-                bs, N, D = pc.shape
-
         if self.use_downsample:
             ###### fps
             from torch_cluster import fps
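
For reference only, outside the patch itself: a minimal standalone sketch of the batched farthest-point-sampling pattern that this change moves ahead of the embedder, assuming torch_cluster is installed. The tensor shapes, feature width, and target resolution below are illustrative assumptions, not values taken from the repository.

    # Standalone illustration with assumed shapes; not the repository's _forward method.
    import torch
    from torch_cluster import fps

    bs, N, D = 2, 4096, 3
    pc = torch.rand(bs, N, D)        # point coordinates, [bs, N, 3]
    feats = torch.rand(bs, N, 4)     # per-point features, [bs, N, C]; C=4 is made up
    resolution = 1024                # assumed target point count per cloud

    # Flatten the batch and tag every point with its cloud index, as fps expects.
    flattened = pc.view(bs * N, D)                        # [bs*N, 3]
    batch = torch.repeat_interleave(torch.arange(bs), N)  # [bs*N]
    idx = fps(flattened, batch, ratio=resolution / N)     # indices into the flattened points

    # Apply the same indices to coordinates and features, then restore the batch dim.
    pc_small = flattened[idx].view(bs, -1, D)
    feats_small = feats.view(bs * N, -1)[idx].view(bs, -1, feats.shape[-1])
    print(pc_small.shape, feats_small.shape)  # [2, 1024, 3] and [2, 1024, 4]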