diff --git a/README.md b/README.md index 220db79..80902a6 100644 --- a/README.md +++ b/README.md @@ -331,4 +331,4 @@ EmbodiedGen builds upon the following amazing projects and models: ## ⚖️ License -This project is licensed under the [Apache License 2.0](LICENSE). See the `LICENSE` file for details. +This project is licensed under the [Apache License 2.0](docs/LICENSE). See the `LICENSE` file for details. diff --git a/apps/app_style.py b/apps/app_style.py index cf27056..a552f9f 100644 --- a/apps/app_style.py +++ b/apps/app_style.py @@ -4,7 +4,7 @@ lighting_css = """ """ diff --git a/apps/common.py b/apps/common.py index c3891f8..3fcac18 100644 --- a/apps/common.py +++ b/apps/common.py @@ -32,8 +32,9 @@ from easydict import EasyDict as edict from PIL import Image from embodied_gen.data.backproject_v2 import entrypoint as backproject_api +from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3 from embodied_gen.data.differentiable_render import entrypoint as render_api -from embodied_gen.data.utils import trellis_preprocess, zip_files +from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files from embodied_gen.models.delight_model import DelightingModel from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.segment_model import ( @@ -131,8 +132,8 @@ def build_covariance_from_scaling_rotation( Gaussian.setup_functions = patched_setup_functions -DELIGHT = DelightingModel() -IMAGESR_MODEL = ImageRealESRGAN(outscale=4) +# DELIGHT = DelightingModel() +# IMAGESR_MODEL = ImageRealESRGAN(outscale=4) # IMAGESR_MODEL = ImageStableSR() if os.getenv("GRADIO_APP") == "imageto3d": RBG_REMOVER = RembgRemover() @@ -169,6 +170,8 @@ def build_covariance_from_scaling_rotation( ) os.makedirs(TMP_DIR, exist_ok=True) elif os.getenv("GRADIO_APP") == "texture_edit": + DELIGHT = DelightingModel() + IMAGESR_MODEL = ImageRealESRGAN(outscale=4) PIPELINE_IP = build_texture_gen_pipe( base_ckpt_dir="./weights", ip_adapt_scale=0.7, @@ -205,7 +208,7 @@ def preprocess_image_fn( elif isinstance(image, np.ndarray): image = Image.fromarray(image) - image_cache = image.copy().resize((512, 512)) + image_cache = resize_pil(image.copy(), 1024) bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER image = bg_remover(image) @@ -221,7 +224,7 @@ def preprocess_sam_image_fn( image = Image.fromarray(image) sam_image = SAM_PREDICTOR.preprocess_image(image) - image_cache = Image.fromarray(sam_image).resize((512, 512)) + image_cache = sam_image.copy() SAM_PREDICTOR.predictor.set_image(sam_image) return sam_image, image_cache @@ -512,6 +515,60 @@ def extract_3d_representations_v2( return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path +def extract_3d_representations_v3( + state: dict, + enable_delight: bool, + texture_size: int, + req: gr.Request, +): + output_root = TMP_DIR + user_dir = os.path.join(output_root, str(req.session_hash)) + gs_model, mesh_model = unpack_state(state, device="cpu") + + filename = "sample" + gs_path = os.path.join(user_dir, f"{filename}_gs.ply") + gs_model.save_ply(gs_path) + + # Rotate mesh and GS by 90 degrees around Z-axis. + rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] + gs_add_rot = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] + mesh_add_rot = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + + # Addtional rotation for GS to align mesh. 
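+    # The composed rotation below is converted to a quaternion instance pose and
+    # applied when re-saving the .ply, so the Gaussians stay aligned with the rotated mesh.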
+ gs_rot = np.array(gs_add_rot) @ np.array(rot_matrix) + pose = GaussianOperator.trans_to_quatpose(gs_rot) + aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") + GaussianOperator.resave_ply( + in_ply=gs_path, + out_ply=aligned_gs_path, + instance_pose=pose, + device="cpu", + ) + + mesh = trimesh.Trimesh( + vertices=mesh_model.vertices.cpu().numpy(), + faces=mesh_model.faces.cpu().numpy(), + ) + mesh.vertices = mesh.vertices @ np.array(mesh_add_rot) + mesh.vertices = mesh.vertices @ np.array(rot_matrix) + + mesh_obj_path = os.path.join(user_dir, f"{filename}.obj") + mesh.export(mesh_obj_path) + + mesh = backproject_api_v3( + gs_path=aligned_gs_path, + mesh_path=mesh_obj_path, + output_path=mesh_obj_path, + skip_fix_mesh=False, + texture_size=texture_size, + ) + + mesh_glb_path = os.path.join(user_dir, f"{filename}.glb") + mesh.export(mesh_glb_path) + + return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path + + def extract_urdf( gs_path: str, mesh_obj_path: str, diff --git a/apps/image_to_3d.py b/apps/image_to_3d.py index 14e4931..d8c1681 100644 --- a/apps/image_to_3d.py +++ b/apps/image_to_3d.py @@ -27,7 +27,7 @@ VERSION, active_btn_by_content, end_session, - extract_3d_representations_v2, + extract_3d_representations_v3, extract_urdf, get_seed, image_to_3d, @@ -179,17 +179,17 @@ ) generate_btn = gr.Button( - "🚀 1. Generate(~0.5 mins)", + "🚀 1. Generate(~2 mins)", variant="primary", interactive=False, ) model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) - with gr.Row(): - extract_rep3d_btn = gr.Button( - "🔍 2. Extract 3D Representation(~2 mins)", - variant="primary", - interactive=False, - ) + # with gr.Row(): + # extract_rep3d_btn = gr.Button( + # "🔍 2. Extract 3D Representation(~2 mins)", + # variant="primary", + # interactive=False, + # ) with gr.Accordion( label="Enter Asset Attributes(optional)", open=False ): @@ -207,7 +207,7 @@ ) with gr.Row(): extract_urdf_btn = gr.Button( - "🧩 3. Extract URDF with physics(~1 mins)", + "🧩 2. Extract URDF with physics(~1 mins)", variant="primary", interactive=False, ) @@ -230,7 +230,7 @@ ) with gr.Row(): download_urdf = gr.DownloadButton( - label="⬇️ 4. Download URDF", + label="⬇️ 3. 
Download URDF", variant="primary", interactive=False, ) @@ -326,7 +326,7 @@ image_prompt.change( lambda: tuple( [ - gr.Button(interactive=False), + # gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), None, @@ -344,7 +344,7 @@ ] ), outputs=[ - extract_rep3d_btn, + # extract_rep3d_btn, extract_urdf_btn, download_urdf, model_output_gs, @@ -375,7 +375,7 @@ image_prompt_sam.change( lambda: tuple( [ - gr.Button(interactive=False), + # gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), None, @@ -394,7 +394,7 @@ ] ), outputs=[ - extract_rep3d_btn, + # extract_rep3d_btn, extract_urdf_btn, download_urdf, model_output_gs, @@ -447,12 +447,7 @@ ], outputs=[output_buf, video_output], ).success( - lambda: gr.Button(interactive=True), - outputs=[extract_rep3d_btn], - ) - - extract_rep3d_btn.click( - extract_3d_representations_v2, + extract_3d_representations_v3, inputs=[ output_buf, project_delight, @@ -495,4 +490,4 @@ if __name__ == "__main__": - demo.launch() + demo.launch(server_port=8081) diff --git a/apps/text_to_3d.py b/apps/text_to_3d.py index 8614b48..8c9012c 100644 --- a/apps/text_to_3d.py +++ b/apps/text_to_3d.py @@ -27,7 +27,7 @@ VERSION, active_btn_by_text_content, end_session, - extract_3d_representations_v2, + extract_3d_representations_v3, extract_urdf, get_cached_image, get_seed, @@ -178,17 +178,17 @@ ) generate_btn = gr.Button( - "🚀 2. Generate 3D(~0.5 mins)", + "🚀 2. Generate 3D(~2 mins)", variant="primary", interactive=False, ) model_output_obj = gr.Textbox(label="raw mesh .obj", visible=False) - with gr.Row(): - extract_rep3d_btn = gr.Button( - "🔍 3. Extract 3D Representation(~1 mins)", - variant="primary", - interactive=False, - ) + # with gr.Row(): + # extract_rep3d_btn = gr.Button( + # "🔍 3. Extract 3D Representation(~1 mins)", + # variant="primary", + # interactive=False, + # ) with gr.Accordion( label="Enter Asset Attributes(optional)", open=False ): @@ -206,13 +206,13 @@ ) with gr.Row(): extract_urdf_btn = gr.Button( - "🧩 4. Extract URDF with physics(~1 mins)", + "🧩 3. Extract URDF with physics(~1 mins)", variant="primary", interactive=False, ) with gr.Row(): download_urdf = gr.DownloadButton( - label="⬇️ 5. Download URDF", + label="⬇️ 4. 
Download URDF", variant="primary", interactive=False, ) @@ -336,7 +336,7 @@ generate_img_btn.click( lambda: tuple( [ - gr.Button(interactive=False), + # gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False), @@ -358,7 +358,7 @@ ] ), outputs=[ - extract_rep3d_btn, + # extract_rep3d_btn, extract_urdf_btn, download_urdf, generate_btn, @@ -428,12 +428,7 @@ ], outputs=[output_buf, video_output], ).success( - lambda: gr.Button(interactive=True), - outputs=[extract_rep3d_btn], - ) - - extract_rep3d_btn.click( - extract_3d_representations_v2, + extract_3d_representations_v3, inputs=[ output_buf, project_delight, @@ -476,4 +471,4 @@ if __name__ == "__main__": - demo.launch() + demo.launch(server_port=8082) diff --git a/apps/texture_edit.py b/apps/texture_edit.py index 5c60b73..722ce6c 100644 --- a/apps/texture_edit.py +++ b/apps/texture_edit.py @@ -381,4 +381,4 @@ def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox): if __name__ == "__main__": - demo.launch() + demo.launch(server_port=8083) diff --git a/apps/visualize_asset.py b/apps/visualize_asset.py index 4d3257c..5e9b94b 100644 --- a/apps/visualize_asset.py +++ b/apps/visualize_asset.py @@ -727,7 +727,6 @@ def update_assets(p, s, c): if __name__ == "__main__": demo.launch( - server_name="10.34.8.77", server_port=8088, allowed_paths=[ "/horizon-bucket/robot_lab/datasets/embodiedgen/assets" diff --git a/LICENSE b/docs/LICENSE similarity index 100% rename from LICENSE rename to docs/LICENSE diff --git a/docs/install.md b/docs/install.md index 8262eba..cf01f06 100644 --- a/docs/install.md +++ b/docs/install.md @@ -14,6 +14,8 @@ conda activate embodiedgen bash install.sh basic ``` +Please `huggingface-cli login` to ensure that the ckpts can be downloaded automatically afterwards. + ## ✅ Starting from Docker We provide a pre-built Docker image on [Docker Hub](https://hub.docker.com/repository/docker/wangxinjie/embodiedgen) with a configured environment for your convenience. For more details, please refer to [Docker documentation](https://github.com/HorizonRobotics/EmbodiedGen/tree/master/docker). diff --git a/embodied_gen/data/asset_converter.py b/embodied_gen/data/asset_converter.py index 5499e53..71ef27e 100644 --- a/embodied_gen/data/asset_converter.py +++ b/embodied_gen/data/asset_converter.py @@ -589,6 +589,8 @@ def convert(self, urdf_path: str, output_file: str): stage = Usd.Stage.Open(usd_path) layer = stage.GetRootLayer() with Usd.EditContext(stage, layer): + base_prim = stage.GetPseudoRoot().GetChildren()[0] + base_prim.SetMetadata("kind", "component") for prim in stage.Traverse(): # Change texture path to relative path. 
if prim.GetName() == "material_0": diff --git a/embodied_gen/data/backproject.py b/embodied_gen/data/backproject.py index b02ae27..94d57d0 100644 --- a/embodied_gen/data/backproject.py +++ b/embodied_gen/data/backproject.py @@ -34,6 +34,7 @@ CameraSetting, get_images_from_grid, init_kal_camera, + kaolin_to_opencv_view, normalize_vertices_array, post_process_texture, save_mesh_with_mtl, @@ -306,28 +307,6 @@ def bake_texture( raise ValueError(f"Unknown mode: {mode}") -def kaolin_to_opencv_view(raw_matrix): - R_orig = raw_matrix[:, :3, :3] - t_orig = raw_matrix[:, :3, 3] - - R_target = torch.zeros_like(R_orig) - R_target[:, :, 0] = R_orig[:, :, 2] - R_target[:, :, 1] = R_orig[:, :, 0] - R_target[:, :, 2] = R_orig[:, :, 1] - - t_target = t_orig - - target_matrix = ( - torch.eye(4, device=raw_matrix.device) - .unsqueeze(0) - .repeat(raw_matrix.size(0), 1, 1) - ) - target_matrix[:, :3, :3] = R_target - target_matrix[:, :3, 3] = t_target - - return target_matrix - - def parse_args(): parser = argparse.ArgumentParser(description="Render settings") diff --git a/embodied_gen/data/backproject_v3.py b/embodied_gen/data/backproject_v3.py new file mode 100644 index 0000000..b22b497 --- /dev/null +++ b/embodied_gen/data/backproject_v3.py @@ -0,0 +1,558 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +import argparse +import logging +import math +import os +from typing import Literal, Union + +import cv2 +import numpy as np +import nvdiffrast.torch as dr +import spaces +import torch +import trimesh +import utils3d +import xatlas +from PIL import Image +from tqdm import tqdm +from embodied_gen.data.mesh_operator import MeshFixer +from embodied_gen.data.utils import ( + CameraSetting, + init_kal_camera, + kaolin_to_opencv_view, + normalize_vertices_array, + post_process_texture, + save_mesh_with_mtl, +) +from embodied_gen.models.delight_model import DelightingModel +from embodied_gen.models.gs_model import load_gs_model +from embodied_gen.models.sr_model import ImageRealESRGAN + +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO +) +logger = logging.getLogger(__name__) + + +__all__ = [ + "TextureBaker", +] + + +class TextureBaker(object): + """Baking textures onto a mesh from multiple observations. + + This class take 3D mesh data, camera settings and texture baking parameters + to generate texture map by projecting images to the mesh from diff views. + It supports both a fast texture baking approach and a more optimized method + with total variation regularization. + + Attributes: + vertices (torch.Tensor): The vertices of the mesh. + faces (torch.Tensor): The faces of the mesh, defined by vertex indices. + uvs (torch.Tensor): The UV coordinates of the mesh. + camera_params (CameraSetting): Camera setting (intrinsics, extrinsics). + device (str): The device to run computations on ("cpu" or "cuda"). + w2cs (torch.Tensor): World-to-camera transformation matrices. 
+ projections (torch.Tensor): Camera projection matrices. + + Example: + >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) # noqa + >>> texture_backer = TextureBaker(vertices, faces, uvs, camera_params) + >>> images = get_images_from_grid(args.color_path, image_size) + >>> texture = texture_backer.bake_texture( + ... images, texture_size=args.texture_size, mode=args.baker_mode + ... ) + >>> texture = post_process_texture(texture) + """ + + def __init__( + self, + vertices: np.ndarray, + faces: np.ndarray, + uvs: np.ndarray, + camera_params: CameraSetting, + device: str = "cuda", + ) -> None: + self.vertices = ( + torch.tensor(vertices, device=device) + if isinstance(vertices, np.ndarray) + else vertices.to(device) + ) + self.faces = ( + torch.tensor(faces.astype(np.int32), device=device) + if isinstance(faces, np.ndarray) + else faces.to(device) + ) + self.uvs = ( + torch.tensor(uvs, device=device) + if isinstance(uvs, np.ndarray) + else uvs.to(device) + ) + self.camera_params = camera_params + self.device = device + + camera = init_kal_camera(camera_params) + matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam + matrix_mv = kaolin_to_opencv_view(matrix_mv) + matrix_p = ( + camera.intrinsics.projection_matrix() + ) # (n_cam 4 4) cam2pixel + self.w2cs = matrix_mv.to(self.device) + self.projections = matrix_p.to(self.device) + + @staticmethod + def parametrize_mesh( + vertices: np.array, faces: np.array + ) -> Union[np.array, np.array, np.array]: + vmapping, indices, uvs = xatlas.parametrize(vertices, faces) + + vertices = vertices[vmapping] + faces = indices + + return vertices, faces, uvs + + def _bake_fast(self, observations, w2cs, projections, texture_size, masks): + texture = torch.zeros( + (texture_size * texture_size, 3), dtype=torch.float32 + ).cuda() + texture_weights = torch.zeros( + (texture_size * texture_size), dtype=torch.float32 + ).cuda() + rastctx = utils3d.torch.RastContext(backend="cuda") + for observation, w2c, projection in tqdm( + zip(observations, w2cs, projections), + total=len(observations), + desc="Texture baking (fast)", + ): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, + self.vertices[None], + self.faces, + observation.shape[1], + observation.shape[0], + uv=self.uvs[None], + view=w2c, + projection=projection, + ) + uv_map = rast["uv"][0].detach().flip(0) + mask = rast["mask"][0].detach().bool() & masks[0] + + # nearest neighbor interpolation + uv_map = (uv_map * texture_size).floor().long() + obs = observation[mask] + uv_map = uv_map[mask] + idx = ( + uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size + ) + texture = texture.scatter_add( + 0, idx.view(-1, 1).expand(-1, 3), obs + ) + texture_weights = texture_weights.scatter_add( + 0, + idx, + torch.ones( + (obs.shape[0]), dtype=torch.float32, device=texture.device + ), + ) + + mask = texture_weights > 0 + texture[mask] /= texture_weights[mask][:, None] + texture = np.clip( + texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255, + 0, + 255, + ).astype(np.uint8) + + # inpaint + mask = ( + (texture_weights == 0) + .cpu() + .numpy() + .astype(np.uint8) + .reshape(texture_size, texture_size) + ) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + return texture + + def _bake_opt( + self, + observations, + w2cs, + projections, + texture_size, + lambda_tv, + masks, + total_steps, + ): + rastctx = utils3d.torch.RastContext(backend="cuda") + observations = [observations.flip(0) for observations in observations] + 
masks = [m.flip(0) for m in masks] + _uv = [] + _uv_dr = [] + for observation, w2c, projection in tqdm( + zip(observations, w2cs, projections), + total=len(w2cs), + ): + with torch.no_grad(): + rast = utils3d.torch.rasterize_triangle_faces( + rastctx, + self.vertices[None], + self.faces, + observation.shape[1], + observation.shape[0], + uv=self.uvs[None], + view=w2c, + projection=projection, + ) + _uv.append(rast["uv"].detach()) + _uv_dr.append(rast["uv_dr"].detach()) + + texture = torch.nn.Parameter( + torch.zeros( + (1, texture_size, texture_size, 3), dtype=torch.float32 + ).cuda() + ) + optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2) + + def cosine_anealing(step, total_steps, start_lr, end_lr): + return end_lr + 0.5 * (start_lr - end_lr) * ( + 1 + np.cos(np.pi * step / total_steps) + ) + + def tv_loss(texture): + return torch.nn.functional.l1_loss( + texture[:, :-1, :, :], texture[:, 1:, :, :] + ) + torch.nn.functional.l1_loss( + texture[:, :, :-1, :], texture[:, :, 1:, :] + ) + + with tqdm(total=total_steps, desc="Texture baking") as pbar: + for step in range(total_steps): + optimizer.zero_grad() + selected = np.random.randint(0, len(w2cs)) + uv, uv_dr, observation, mask = ( + _uv[selected], + _uv_dr[selected], + observations[selected], + masks[selected], + ) + render = dr.texture(texture, uv, uv_dr)[0] + loss = torch.nn.functional.l1_loss( + render[mask], observation[mask] + ) + if lambda_tv > 0: + loss += lambda_tv * tv_loss(texture) + loss.backward() + optimizer.step() + + optimizer.param_groups[0]["lr"] = cosine_anealing( + step, total_steps, 1e-2, 1e-5 + ) + pbar.set_postfix({"loss": loss.item()}) + pbar.update() + + texture = np.clip( + texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255 + ).astype(np.uint8) + mask = 1 - utils3d.torch.rasterize_triangle_faces( + rastctx, + (self.uvs * 2 - 1)[None], + self.faces, + texture_size, + texture_size, + )["mask"][0].detach().cpu().numpy().astype(np.uint8) + texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA) + + return texture + + def bake_texture( + self, + images: list[np.array], + texture_size: int = 1024, + mode: Literal["fast", "opt"] = "opt", + lambda_tv: float = 1e-2, + opt_step: int = 2000, + ): + masks = [np.any(img > 0, axis=-1) for img in images] + masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks] + images = [ + torch.tensor(obs / 255.0).float().to(self.device) for obs in images + ] + + if mode == "fast": + return self._bake_fast( + images, self.w2cs, self.projections, texture_size, masks + ) + elif mode == "opt": + return self._bake_opt( + images, + self.w2cs, + self.projections, + texture_size, + lambda_tv, + masks, + opt_step, + ) + else: + raise ValueError(f"Unknown mode: {mode}") + + +def parse_args(): + """Parses command-line arguments for texture backprojection. + + Returns: + argparse.Namespace: Parsed arguments. 
+ """ + parser = argparse.ArgumentParser(description="Backproject texture") + parser.add_argument( + "--gs_path", + type=str, + help="Path to the GS.ply gaussian splatting model", + ) + parser.add_argument( + "--mesh_path", + type=str, + help="Mesh path, .obj, .glb or .ply", + ) + parser.add_argument( + "--output_path", + type=str, + help="Output mesh path with suffix", + ) + parser.add_argument( + "--num_images", + type=int, + default=180, + help="Number of images to render.", + ) + parser.add_argument( + "--elevation", + nargs="+", + type=float, + default=list(range(85, -90, -10)), + help="Elevation angles for the camera", + ) + parser.add_argument( + "--distance", + type=float, + default=5, + help="Camera distance (default: 5)", + ) + parser.add_argument( + "--resolution_hw", + type=int, + nargs=2, + default=(512, 512), + help="Resolution of the render images (default: (512, 512))", + ) + parser.add_argument( + "--fov", + type=float, + default=30, + help="Field of view in degrees (default: 30)", + ) + parser.add_argument( + "--device", + type=str, + choices=["cpu", "cuda"], + default="cuda", + help="Device to run on (default: `cuda`)", + ) + parser.add_argument( + "--skip_fix_mesh", action="store_true", help="Fix mesh geometry." + ) + parser.add_argument( + "--texture_size", + type=int, + default=2048, + help="Texture size for texture baking (default: 1024)", + ) + parser.add_argument( + "--baker_mode", + type=str, + default="opt", + help="Texture baking mode, `fast` or `opt` (default: opt)", + ) + parser.add_argument( + "--opt_step", + type=int, + default=3000, + help="Optimization steps for texture baking (default: 3000)", + ) + parser.add_argument( + "--mesh_sipmlify_ratio", + type=float, + default=0.9, + help="Mesh simplification ratio (default: 0.9)", + ) + parser.add_argument( + "--delight", action="store_true", help="Use delighting model." + ) + parser.add_argument( + "--no_smooth_texture", + action="store_true", + help="Do not smooth the texture.", + ) + parser.add_argument( + "--no_coor_trans", + action="store_true", + help="Do not transform the asset coordinate system.", + ) + parser.add_argument( + "--save_glb_path", type=str, default=None, help="Save glb path." + ) + parser.add_argument("--n_max_faces", type=int, default=30000) + args, unknown = parser.parse_known_args() + + return args + + +@spaces.GPU +def entrypoint( + delight_model: DelightingModel = None, + imagesr_model: ImageRealESRGAN = None, + **kwargs, +) -> trimesh.Trimesh: + """Entrypoint for texture backprojection from multi-view images. + + Args: + delight_model (DelightingModel, optional): Delighting model. + imagesr_model (ImageRealESRGAN, optional): Super-resolution model. + **kwargs: Additional arguments to override CLI. + + Returns: + trimesh.Trimesh: Textured mesh. + """ + args = parse_args() + for k, v in kwargs.items(): + if hasattr(args, k) and v is not None: + setattr(args, k, v) + + # Setup camera parameters. + camera_params = CameraSetting( + num_images=args.num_images, + elevation=args.elevation, + distance=args.distance, + resolution_hw=args.resolution_hw, + fov=math.radians(args.fov), + device=args.device, + ) + + # GS render. 
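+    # Render the Gaussian splatting model from every camera pose below; these
+    # multi-view color images are the observations later baked into the mesh texture.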
+ camera = init_kal_camera(camera_params, flip_az=True) + matrix_mv = camera.view_matrix() # (n_cam 4 4) world2cam + matrix_mv[:, :3, 3] = -matrix_mv[:, :3, 3] + w2cs = matrix_mv.to(camera_params.device) + c2ws = [torch.linalg.inv(matrix) for matrix in w2cs] + Ks = torch.tensor(camera_params.Ks).to(camera_params.device) + gs_model = load_gs_model(args.gs_path, pre_quat=[0.0, 0.0, 1.0, 0.0]) + multiviews = [] + for idx in tqdm(range(len(c2ws)), desc="Rendering GS"): + result = gs_model.render( + c2ws[idx], + Ks=Ks, + image_width=camera_params.resolution_hw[1], + image_height=camera_params.resolution_hw[0], + ) + color = cv2.cvtColor(result.rgba, cv2.COLOR_BGRA2RGBA) + multiviews.append(Image.fromarray(color)) + + if args.delight and delight_model is None: + delight_model = DelightingModel() + + if args.delight: + for idx in range(len(multiviews)): + multiviews[idx] = delight_model(multiviews[idx]) + + multiviews = [img.convert("RGB") for img in multiviews] + + mesh = trimesh.load(args.mesh_path) + if isinstance(mesh, trimesh.Scene): + mesh = mesh.dump(concatenate=True) + + vertices, scale, center = normalize_vertices_array(mesh.vertices) + + # Transform mesh coordinate system by default. + if not args.no_coor_trans: + x_rot = np.array([[1, 0, 0], [0, 0, 1], [0, -1, 0]]) + z_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + vertices = vertices @ x_rot + vertices = vertices @ z_rot + + faces = mesh.faces.astype(np.int32) + vertices = vertices.astype(np.float32) + + if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces: + mesh_fixer = MeshFixer(vertices, faces, args.device) + vertices, faces = mesh_fixer( + filter_ratio=args.mesh_sipmlify_ratio, + max_hole_size=0.04, + resolution=1024, + num_views=1000, + norm_mesh_ratio=0.5, + ) + if len(faces) > args.n_max_faces: + mesh_fixer = MeshFixer(vertices, faces, args.device) + vertices, faces = mesh_fixer( + filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2), + max_hole_size=0.04, + resolution=1024, + num_views=1000, + norm_mesh_ratio=0.5, + ) + + vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces) + texture_backer = TextureBaker( + vertices, + faces, + uvs, + camera_params, + ) + + multiviews = [np.array(img) for img in multiviews] + texture = texture_backer.bake_texture( + images=[img[..., :3] for img in multiviews], + texture_size=args.texture_size, + mode=args.baker_mode, + opt_step=args.opt_step, + ) + if not args.no_smooth_texture: + texture = post_process_texture(texture) + + # Recover mesh original orientation, scale and center. 
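+    # Invert the earlier coordinate transform and normalization so the textured
+    # mesh is written back in the input mesh's original frame.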
+ if not args.no_coor_trans: + vertices = vertices @ np.linalg.inv(z_rot) + vertices = vertices @ np.linalg.inv(x_rot) + vertices = vertices / scale + vertices = vertices + center + + textured_mesh = save_mesh_with_mtl( + vertices, faces, uvs, texture, args.output_path + ) + if args.save_glb_path is not None: + os.makedirs(os.path.dirname(args.save_glb_path), exist_ok=True) + textured_mesh.export(args.save_glb_path) + + return textured_mesh + + +if __name__ == "__main__": + entrypoint() diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py index 98900d9..fa2f7d5 100644 --- a/embodied_gen/data/utils.py +++ b/embodied_gen/data/utils.py @@ -66,6 +66,7 @@ "resize_pil", "trellis_preprocess", "delete_dir", + "kaolin_to_opencv_view", ] @@ -373,10 +374,18 @@ def _compute_az_el_by_views( def _compute_cam_pts_by_az_el( azs: np.ndarray, els: np.ndarray, - distance: float, + distance: float | list[float] | np.ndarray, extra_pts: np.ndarray = None, ) -> np.ndarray: - distances = np.array([distance for _ in range(len(azs))]) + if np.isscalar(distance) or isinstance(distance, (float, int)): + distances = np.full(len(azs), distance) + else: + distances = np.array(distance) + if len(distances) != len(azs): + raise ValueError( + f"Length of distances ({len(distances)}) must match length of azs ({len(azs)})" + ) + cam_pts = _az_el_to_points(azs, els) * distances[:, None] if extra_pts is not None: @@ -710,7 +719,7 @@ class CameraSetting: num_images: int elevation: list[float] - distance: float + distance: float | list[float] resolution_hw: tuple[int, int] fov: float at: tuple[float, float, float] = field( @@ -824,6 +833,28 @@ def import_kaolin_mesh(mesh_path: str, with_mtl: bool = False): return mesh +def kaolin_to_opencv_view(raw_matrix): + R_orig = raw_matrix[:, :3, :3] + t_orig = raw_matrix[:, :3, 3] + + R_target = torch.zeros_like(R_orig) + R_target[:, :, 0] = R_orig[:, :, 2] + R_target[:, :, 1] = R_orig[:, :, 0] + R_target[:, :, 2] = R_orig[:, :, 1] + + t_target = t_orig + + target_matrix = ( + torch.eye(4, device=raw_matrix.device) + .unsqueeze(0) + .repeat(raw_matrix.size(0), 1, 1) + ) + target_matrix[:, :3, :3] = R_target + target_matrix[:, :3, 3] = t_target + + return target_matrix + + def save_mesh_with_mtl( vertices: np.ndarray, faces: np.ndarray, diff --git a/embodied_gen/models/gs_model.py b/embodied_gen/models/gs_model.py index 7b40b56..003abd3 100644 --- a/embodied_gen/models/gs_model.py +++ b/embodied_gen/models/gs_model.py @@ -21,14 +21,18 @@ from dataclasses import dataclass from typing import Optional -import cv2 import numpy as np import torch from gsplat.cuda._wrapper import spherical_harmonics from gsplat.rendering import rasterization from plyfile import PlyData from scipy.spatial.transform import Rotation -from embodied_gen.data.utils import gamma_shs, quat_mult, quat_to_rotmat +from embodied_gen.data.utils import ( + gamma_shs, + normalize_vertices_array, + quat_mult, + quat_to_rotmat, +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -494,6 +498,21 @@ def render( ) +def load_gs_model( + input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071] +) -> GaussianOperator: + gs_model = GaussianOperator.load_from_ply(input_gs) + # Normalize vertices to [-1, 1], center to (0, 0, 0). 
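+    # The scale/center estimated from the splat means are applied through an
+    # instance pose (translation plus the pre_quat rotation) followed by a rescale.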
+ _, scale, center = normalize_vertices_array(gs_model._means) + scale, center = float(scale), center.tolist() + transpose = [*[v for v in center], *pre_quat] + instance_pose = torch.tensor(transpose).to(gs_model.device) + gs_model = gs_model.get_gaussians(instance_pose=instance_pose) + gs_model.rescale(scale) + + return gs_model + + if __name__ == "__main__": input_gs = "outputs/layouts_gens_demo/task_0000/background/gs_model.ply" output_gs = "./gs_model.ply" diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py index 19719e1..61a14b6 100644 --- a/embodied_gen/scripts/imageto3d.py +++ b/embodied_gen/scripts/imageto3d.py @@ -26,12 +26,14 @@ import torch import trimesh from PIL import Image -from embodied_gen.data.backproject_v2 import entrypoint as backproject_api +from embodied_gen.data.backproject_v3 import entrypoint as backproject_api from embodied_gen.data.utils import delete_dir, trellis_preprocess -from embodied_gen.models.delight_model import DelightingModel + +# from embodied_gen.models.delight_model import DelightingModel from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.segment_model import RembgRemover -from embodied_gen.models.sr_model import ImageRealESRGAN + +# from embodied_gen.models.sr_model import ImageRealESRGAN from embodied_gen.scripts.render_gs import entrypoint as render_gs_api from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.utils.log import logger @@ -59,8 +61,8 @@ random.seed(0) logger.info("Loading Image3D Models...") -DELIGHT = DelightingModel() -IMAGESR_MODEL = ImageRealESRGAN(outscale=4) +# DELIGHT = DelightingModel() +# IMAGESR_MODEL = ImageRealESRGAN(outscale=4) RBG_REMOVER = RembgRemover() PIPELINE = TrellisImageTo3DPipeline.from_pretrained( "microsoft/TRELLIS-image-large" @@ -108,9 +110,7 @@ def parse_args(): default=2, ) parser.add_argument("--disable_decompose_convex", action="store_true") - parser.add_argument( - "--texture_wh", type=int, nargs=2, default=[2048, 2048] - ) + parser.add_argument("--texture_size", type=int, default=2048) args, unknown = parser.parse_known_args() return args @@ -248,16 +248,14 @@ def entrypoint(**kwargs): mesh.export(mesh_obj_path) mesh = backproject_api( - delight_model=DELIGHT, - imagesr_model=IMAGESR_MODEL, - color_path=color_path, + # delight_model=DELIGHT, + # imagesr_model=IMAGESR_MODEL, + gs_path=aligned_gs_path, mesh_path=mesh_obj_path, output_path=mesh_obj_path, skip_fix_mesh=False, - delight=True, - texture_wh=args.texture_wh, - elevation=[20, -10, 60, -50], - num_images=12, + texture_size=args.texture_size, + delight=False, ) mesh_glb_path = os.path.join(output_root, f"{filename}.glb") diff --git a/embodied_gen/scripts/render_gs.py b/embodied_gen/scripts/render_gs.py index 2c8459d..3a3d7a2 100644 --- a/embodied_gen/scripts/render_gs.py +++ b/embodied_gen/scripts/render_gs.py @@ -29,7 +29,7 @@ init_kal_camera, normalize_vertices_array, ) -from embodied_gen.models.gs_model import GaussianOperator +from embodied_gen.models.gs_model import load_gs_model from embodied_gen.utils.process_media import combine_images_to_grid logging.basicConfig( @@ -97,21 +97,6 @@ def parse_args(): return args -def load_gs_model( - input_gs: str, pre_quat: list[float] = [0.0, 0.7071, 0.0, -0.7071] -) -> GaussianOperator: - gs_model = GaussianOperator.load_from_ply(input_gs) - # Normalize vertices to [-1, 1], center to (0, 0, 0). 
- _, scale, center = normalize_vertices_array(gs_model._means) - scale, center = float(scale), center.tolist() - transpose = [*[v for v in center], *pre_quat] - instance_pose = torch.tensor(transpose).to(gs_model.device) - gs_model = gs_model.get_gaussians(instance_pose=instance_pose) - gs_model.rescale(scale) - - return gs_model - - @spaces.GPU def entrypoint(**kwargs) -> None: args = parse_args() diff --git a/mkdocs.yml b/mkdocs.yml index d55214c..d5a421c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,7 +94,6 @@ plugins: docstring_style: google show_source: true merge_init_into_class: true - show_inherited_members: true show_root_heading: true show_root_full_path: true