From 957d16b9bf0a85426b0ea13b1c835922468bec10 Mon Sep 17 00:00:00 2001 From: "xinjie.wang" Date: Wed, 24 Dec 2025 20:24:27 +0800 Subject: [PATCH] feat(model): Add sam3d model --- .gitmodules | 5 + README.md | 20 +- apps/app_style.py | 18 +- apps/common.py | 259 +++-------- apps/image_to_3d.py | 50 +- apps/text_to_3d.py | 39 +- apps/texture_edit.py | 3 +- apps/visualize_asset.py | 4 +- docs/acknowledgement.md | 2 +- docs/install.md | 5 +- docs/services/image_to_3d.md | 2 + docs/tutorials/image_to_3d.md | 3 +- docs/tutorials/layout_gen.md | 1 + docs/tutorials/text_to_3d.md | 4 +- embodied_gen/data/asset_converter.py | 17 + embodied_gen/data/backproject_v2.py | 12 +- embodied_gen/data/backproject_v3.py | 12 +- embodied_gen/data/utils.py | 186 ++------ embodied_gen/models/sam3d.py | 152 +++++++ embodied_gen/models/segment_model.py | 24 +- embodied_gen/scripts/gen_scene3d.py | 17 + embodied_gen/scripts/gen_texture.py | 18 + embodied_gen/scripts/imageto3d.py | 84 ++-- embodied_gen/scripts/render_gs.py | 1 - embodied_gen/scripts/textto3d.py | 6 +- embodied_gen/utils/gpt_clients.py | 13 +- embodied_gen/utils/inference.py | 59 +++ embodied_gen/utils/monkey_patches.py | 427 ++++++++++++++++++ embodied_gen/utils/process_media.py | 31 +- embodied_gen/utils/tags.py | 2 +- embodied_gen/utils/trender.py | 113 ++++- .../validators/aesthetic_predictor.py | 6 +- embodied_gen/validators/quality_checkers.py | 33 +- embodied_gen/validators/urdf_convertor.py | 7 +- install/install_basic.sh | 4 +- install/install_extra.sh | 6 +- pyproject.toml | 2 +- requirements.txt | 11 +- tests/test_examples/test_quality_checkers.py | 8 +- 39 files changed, 1192 insertions(+), 474 deletions(-) create mode 100644 embodied_gen/models/sam3d.py create mode 100644 embodied_gen/utils/inference.py diff --git a/.gitmodules b/.gitmodules index c6b0a7b..d215578 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,3 +8,8 @@ url = https://github.com/TrickyGo/Pano2Room.git branch = main shallow = true +[submodule "thirdparty/sam3d"] + path = thirdparty/sam3d + url = https://github.com/HochCC/sam-3d-objects.git + branch = main + shallow = true diff --git a/README.md b/README.md index 80902a6..9a14480 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,12 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.6 +git checkout v0.1.7 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen -bash install.sh basic +bash install.sh basic # around 20 mins +# Optional: `bash install.sh extra` for scene3d-cli ``` ### ✅ Starting from Docker @@ -94,12 +95,14 @@ CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 & ### ⚡ API Generate physically plausible 3D assets from image input via the command-line API. ```sh -img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \ +img3d-cli --image_path apps/assets/example_image/sample_00.jpg apps/assets/example_image/sample_01.jpg \ --n_retry 1 --output_root outputs/imageto3d # See result(.urdf/mesh.obj/mesh.glb/gs.ply) in ${output_root}/sample_xx/result ``` +Support the use of [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) as 3D generation model, modify `IMAGE3D_MODEL` in `embodied_gen/scripts/imageto3d.py` to switch model. 
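A minimal sketch of the model switch referenced above, assuming the module-level constant this patch adds to `embodied_gen/scripts/imageto3d.py` (see the imageto3d.py hunk later in this diff):

```python
# embodied_gen/scripts/imageto3d.py (as modified by this patch)
IMAGE3D_MODEL = "SAM3D"  # default; set to "TRELLIS" to use the TRELLIS pipeline instead

if IMAGE3D_MODEL == "TRELLIS":
    from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
    PIPELINE = TrellisImageTo3DPipeline.from_pretrained("microsoft/TRELLIS-image-large")
elif IMAGE3D_MODEL == "SAM3D":
    from embodied_gen.models.sam3d import Sam3dInference
    PIPELINE = Sam3dInference()
```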
+ --- @@ -133,7 +136,7 @@ text3d-cli --prompts "small bronze figurine of a lion" "A globe with wooden base Text-to-image model based on the Kolors model. ```sh bash embodied_gen/scripts/textto3d.sh \ - --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ + --prompts "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ --output_root outputs/textto3d_k ``` ps: models with more permissive licenses found in `embodied_gen/models/image_comm_model.py` @@ -191,7 +194,11 @@ CUDA_VISIBLE_DEVICES=0 scene3d-cli \

⚙️ Articulated Object Generation

-🚧 *Coming Soon* +See our paper published in NeurIPS 2025. +[[Arxiv Paper]](https://arxiv.org/abs/2505.20460) | +[[Gradio Demo]](https://huggingface.co/spaces/HorizonRobotics/DIPO) | +[[Code]](https://github.com/RQ-Wu/DIPO) + articulate @@ -239,6 +246,7 @@ Remove `--insert_robot` if you don't consider the robot pose in layout generatio CUDA_VISIBLE_DEVICES=0 nohup layout-cli \ --task_descs "apps/assets/example_layout/task_list.txt" \ --bg_list "outputs/bg_scenes/scene_list.txt" \ +--n_image_retry 4 --n_asset_retry 3 --n_pipe_retry 2 \ --output_root "outputs/layouts_gens" --insert_robot > layouts_gens.log & ``` @@ -325,7 +333,7 @@ If you use EmbodiedGen in your research or projects, please cite: ## 🙌 Acknowledgement EmbodiedGen builds upon the following amazing projects and models: -🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill) +🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill) | 🌟 [SAM3D](https://github.com/facebookresearch/sam-3d-objects) --- diff --git a/apps/app_style.py b/apps/app_style.py index a552f9f..313ccd1 100644 --- a/apps/app_style.py +++ b/apps/app_style.py @@ -1,10 +1,26 @@ +# Project EmbodiedGen +# +# Copyright (c) 
2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + from gradio.themes import Soft from gradio.themes.utils.colors import gray, neutral, slate, stone, teal, zinc lighting_css = """ """ diff --git a/apps/common.py b/apps/common.py index 3fcac18..55e30da 100644 --- a/apps/common.py +++ b/apps/common.py @@ -14,6 +14,11 @@ # implied. See the License for the specific language governing # permissions and limitations under the License. +import spaces +from embodied_gen.utils.monkey_patches import monkey_path_trellis + +monkey_path_trellis() + import gc import logging import os @@ -25,18 +30,16 @@ import cv2 import gradio as gr import numpy as np -import spaces import torch -import torch.nn.functional as F import trimesh -from easydict import EasyDict as edict from PIL import Image from embodied_gen.data.backproject_v2 import entrypoint as backproject_api from embodied_gen.data.backproject_v3 import entrypoint as backproject_api_v3 from embodied_gen.data.differentiable_render import entrypoint as render_api -from embodied_gen.data.utils import resize_pil, trellis_preprocess, zip_files +from embodied_gen.data.utils import trellis_preprocess, zip_files from embodied_gen.models.delight_model import DelightingModel from embodied_gen.models.gs_model import GaussianOperator +from embodied_gen.models.sam3d import Sam3dInference from embodied_gen.models.segment_model import ( BMGG14Remover, RembgRemover, @@ -53,10 +56,11 @@ from embodied_gen.utils.gpt_clients import GPT_CLIENT from embodied_gen.utils.process_media import ( filter_image_small_connected_components, + keep_largest_connected_component, merge_images_video, ) from embodied_gen.utils.tags import VERSION -from embodied_gen.utils.trender import render_video +from embodied_gen.utils.trender import pack_state, render_video, unpack_state from embodied_gen.validators.quality_checkers import ( BaseChecker, ImageAestheticChecker, @@ -69,15 +73,6 @@ current_dir = os.path.dirname(current_file_path) sys.path.append(os.path.join(current_dir, "..")) from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline -from thirdparty.TRELLIS.trellis.representations import ( - Gaussian, - MeshExtractResult, -) -from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import ( - build_scaling_rotation, - inverse_sigmoid, - strip_symmetric, -) from thirdparty.TRELLIS.trellis.utils import postprocessing_utils logging.basicConfig( @@ -85,64 +80,24 @@ ) logger = logging.getLogger(__name__) - -os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( - "~/.cache/torch_extensions" -) os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" -os.environ["SPCONV_ALGO"] = "native" MAX_SEED = 100000 - -def patched_setup_functions(self): - def inverse_softplus(x): - return x + torch.log(-torch.expm1(-x)) - - def build_covariance_from_scaling_rotation( - scaling, scaling_modifier, rotation - ): - L = build_scaling_rotation(scaling_modifier * scaling, rotation) - actual_covariance = L @ L.transpose(1, 2) - symm = 
strip_symmetric(actual_covariance) - return symm - - if self.scaling_activation_type == "exp": - self.scaling_activation = torch.exp - self.inverse_scaling_activation = torch.log - elif self.scaling_activation_type == "softplus": - self.scaling_activation = F.softplus - self.inverse_scaling_activation = inverse_softplus - - self.covariance_activation = build_covariance_from_scaling_rotation - self.opacity_activation = torch.sigmoid - self.inverse_opacity_activation = inverse_sigmoid - self.rotation_activation = F.normalize - - self.scale_bias = self.inverse_scaling_activation( - torch.tensor(self.scaling_bias) - ).to(self.device) - self.rots_bias = torch.zeros((4)).to(self.device) - self.rots_bias[0] = 1 - self.opacity_bias = self.inverse_opacity_activation( - torch.tensor(self.opacity_bias) - ).to(self.device) - - -Gaussian.setup_functions = patched_setup_functions - - # DELIGHT = DelightingModel() # IMAGESR_MODEL = ImageRealESRGAN(outscale=4) # IMAGESR_MODEL = ImageStableSR() -if os.getenv("GRADIO_APP") == "imageto3d": +if os.getenv("GRADIO_APP").startswith("imageto3d"): RBG_REMOVER = RembgRemover() RBG14_REMOVER = BMGG14Remover() SAM_PREDICTOR = SAMPredictor(model_type="vit_h", device="cpu") - PIPELINE = TrellisImageTo3DPipeline.from_pretrained( - "microsoft/TRELLIS-image-large" - ) - # PIPELINE.cuda() + if "sam3d" in os.getenv("GRADIO_APP"): + PIPELINE = Sam3dInference() + else: + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() SEG_CHECKER = ImageSegChecker(GPT_CLIENT) GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) AESTHETIC_CHECKER = ImageAestheticChecker() @@ -151,13 +106,16 @@ def build_covariance_from_scaling_rotation( os.path.dirname(os.path.abspath(__file__)), "sessions/imageto3d" ) os.makedirs(TMP_DIR, exist_ok=True) -elif os.getenv("GRADIO_APP") == "textto3d": +elif os.getenv("GRADIO_APP").startswith("textto3d"): RBG_REMOVER = RembgRemover() RBG14_REMOVER = BMGG14Remover() - PIPELINE = TrellisImageTo3DPipeline.from_pretrained( - "microsoft/TRELLIS-image-large" - ) - # PIPELINE.cuda() + if "sam3d" in os.getenv("GRADIO_APP"): + PIPELINE = Sam3dInference() + else: + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() text_model_dir = "weights/Kolors" PIPELINE_IMG_IP = build_text2img_ip_pipeline(text_model_dir, ref_scale=0.3) PIPELINE_IMG = build_text2img_pipeline(text_model_dir) @@ -201,18 +159,23 @@ def end_session(req: gr.Request) -> None: @spaces.GPU def preprocess_image_fn( - image: str | np.ndarray | Image.Image, rmbg_tag: str = "rembg" + image: str | np.ndarray | Image.Image, + rmbg_tag: str = "rembg", + preprocess: bool = True, ) -> tuple[Image.Image, Image.Image]: if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): image = Image.fromarray(image) - image_cache = resize_pil(image.copy(), 1024) + image_cache = image.copy() # resize_pil(image.copy(), 1024) bg_remover = RBG_REMOVER if rmbg_tag == "rembg" else RBG14_REMOVER image = bg_remover(image) - image = trellis_preprocess(image) + image = keep_largest_connected_component(image) + + if preprocess: + image = trellis_preprocess(image) return image, image_cache @@ -264,50 +227,6 @@ def get_cached_image(image_path: str) -> Image.Image: return Image.open(image_path).resize((512, 512)) -@spaces.GPU -def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: - return { - "gaussian": { - **gs.init_params, - "_xyz": gs._xyz.cpu().numpy(), - "_features_dc": 
gs._features_dc.cpu().numpy(), - "_scaling": gs._scaling.cpu().numpy(), - "_rotation": gs._rotation.cpu().numpy(), - "_opacity": gs._opacity.cpu().numpy(), - }, - "mesh": { - "vertices": mesh.vertices.cpu().numpy(), - "faces": mesh.faces.cpu().numpy(), - }, - } - - -def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]: - gs = Gaussian( - aabb=state["gaussian"]["aabb"], - sh_degree=state["gaussian"]["sh_degree"], - mininum_kernel_size=state["gaussian"]["mininum_kernel_size"], - scaling_bias=state["gaussian"]["scaling_bias"], - opacity_bias=state["gaussian"]["opacity_bias"], - scaling_activation=state["gaussian"]["scaling_activation"], - device=device, - ) - gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device) - gs._features_dc = torch.tensor( - state["gaussian"]["_features_dc"], device=device - ) - gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device) - gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device) - gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device) - - mesh = edict( - vertices=torch.tensor(state["mesh"]["vertices"], device=device), - faces=torch.tensor(state["mesh"]["faces"], device=device), - ) - - return gs, mesh - - def get_seed(randomize_seed: bool, seed: int, max_seed: int = MAX_SEED) -> int: return np.random.randint(0, max_seed) if randomize_seed else seed @@ -349,11 +268,11 @@ def select_point( def image_to_3d( image: Image.Image, seed: int, - ss_guidance_strength: float, ss_sampling_steps: int, - slat_guidance_strength: float, slat_sampling_steps: int, raw_image_cache: Image.Image, + ss_guidance_strength: float, + slat_guidance_strength: float, sam_image: Image.Image = None, is_sam_image: bool = False, req: gr.Request = None, @@ -361,39 +280,48 @@ def image_to_3d( if is_sam_image: seg_image = filter_image_small_connected_components(sam_image) seg_image = Image.fromarray(seg_image, mode="RGBA") - seg_image = trellis_preprocess(seg_image) else: seg_image = image if isinstance(seg_image, np.ndarray): seg_image = Image.fromarray(seg_image) + if isinstance(PIPELINE, Sam3dInference): + outputs = PIPELINE.run( + seg_image, + seed=seed, + stage1_inference_steps=ss_sampling_steps, + stage2_inference_steps=slat_sampling_steps, + ) + else: + PIPELINE.cuda() + seg_image = trellis_preprocess(seg_image) + outputs = PIPELINE.run( + seg_image, + seed=seed, + formats=["gaussian", "mesh"], + preprocess_image=False, + sparse_structure_sampler_params={ + "steps": ss_sampling_steps, + "cfg_strength": ss_guidance_strength, + }, + slat_sampler_params={ + "steps": slat_sampling_steps, + "cfg_strength": slat_guidance_strength, + }, + ) + # Set back to cpu for memory saving. + PIPELINE.cpu() + + gs_model = outputs["gaussian"][0] + mesh_model = outputs["mesh"][0] + color_images = render_video(gs_model, r=1.85)["color"] + normal_images = render_video(mesh_model, r=1.85)["normal"] + output_root = os.path.join(TMP_DIR, str(req.session_hash)) os.makedirs(output_root, exist_ok=True) seg_image.save(f"{output_root}/seg_image.png") raw_image_cache.save(f"{output_root}/raw_image.png") - PIPELINE.cuda() - outputs = PIPELINE.run( - seg_image, - seed=seed, - formats=["gaussian", "mesh"], - preprocess_image=False, - sparse_structure_sampler_params={ - "steps": ss_sampling_steps, - "cfg_strength": ss_guidance_strength, - }, - slat_sampler_params={ - "steps": slat_sampling_steps, - "cfg_strength": slat_guidance_strength, - }, - ) - # Set to cpu for memory saving. 
- PIPELINE.cpu() - - gs_model = outputs["gaussian"][0] - mesh_model = outputs["mesh"][0] - color_images = render_video(gs_model)["color"] - normal_images = render_video(mesh_model)["normal"] video_path = os.path.join(output_root, "gs_mesh.mp4") merge_images_video(color_images, normal_images, video_path) @@ -405,56 +333,13 @@ def image_to_3d( return state, video_path -@spaces.GPU -def extract_3d_representations( - state: dict, enable_delight: bool, texture_size: int, req: gr.Request -): - output_root = TMP_DIR - output_root = os.path.join(output_root, str(req.session_hash)) - gs_model, mesh_model = unpack_state(state, device="cuda") - - mesh = postprocessing_utils.to_glb( - gs_model, - mesh_model, - simplify=0.9, - texture_size=1024, - verbose=True, - ) - filename = "sample" - gs_path = os.path.join(output_root, f"{filename}_gs.ply") - gs_model.save_ply(gs_path) - - # Rotate mesh and GS by 90 degrees around Z-axis. - rot_matrix = [[0, 0, -1], [0, 1, 0], [1, 0, 0]] - # Addtional rotation for GS to align mesh. - gs_rot = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]]) @ np.array( - rot_matrix - ) - pose = GaussianOperator.trans_to_quatpose(gs_rot) - aligned_gs_path = gs_path.replace(".ply", "_aligned.ply") - GaussianOperator.resave_ply( - in_ply=gs_path, - out_ply=aligned_gs_path, - instance_pose=pose, - ) - - mesh.vertices = mesh.vertices @ np.array(rot_matrix) - mesh_obj_path = os.path.join(output_root, f"{filename}.obj") - mesh.export(mesh_obj_path) - mesh_glb_path = os.path.join(output_root, f"{filename}.glb") - mesh.export(mesh_glb_path) - - torch.cuda.empty_cache() - - return mesh_glb_path, gs_path, mesh_obj_path, aligned_gs_path - - def extract_3d_representations_v2( state: dict, enable_delight: bool, texture_size: int, req: gr.Request, ): + """Back-Projection Version of Texture Super-Resolution.""" output_root = TMP_DIR user_dir = os.path.join(output_root, str(req.session_hash)) gs_model, mesh_model = unpack_state(state, device="cpu") @@ -521,6 +406,7 @@ def extract_3d_representations_v3( texture_size: int, req: gr.Request, ): + """Back-Projection Version with Optimization-Based.""" output_root = TMP_DIR user_dir = os.path.join(output_root, str(req.session_hash)) gs_model, mesh_model = unpack_state(state, device="cpu") @@ -688,6 +574,7 @@ def text2image_fn( image_wh: int | tuple[int, int] = [1024, 1024], rmbg_tag: str = "rembg", seed: int = None, + enable_pre_resize: bool = True, n_sample: int = 3, req: gr.Request = None, ): @@ -715,7 +602,9 @@ def text2image_fn( for idx in range(len(images)): image = images[idx] - images[idx], _ = preprocess_image_fn(image, rmbg_tag) + images[idx], _ = preprocess_image_fn( + image, rmbg_tag, enable_pre_resize + ) save_paths = [] for idx, image in enumerate(images): @@ -841,6 +730,7 @@ def backproject_texture_v2( texture_size: int, enable_delight: bool = True, fix_mesh: bool = False, + no_mesh_post_process: bool = False, uuid: str = "sample", req: gr.Request = None, ) -> str: @@ -857,6 +747,7 @@ def backproject_texture_v2( skip_fix_mesh=not fix_mesh, delight=enable_delight, texture_wh=[texture_size, texture_size], + no_mesh_post_process=no_mesh_post_process, ) output_obj_mesh = os.path.join(output_dir, f"{uuid}.obj") diff --git a/apps/image_to_3d.py b/apps/image_to_3d.py index d8c1681..5bf5fc9 100644 --- a/apps/image_to_3d.py +++ b/apps/image_to_3d.py @@ -17,7 +17,9 @@ import os -os.environ["GRADIO_APP"] = "imageto3d" +# GRADIO_APP == "imageto3d_sam3d", sam3d object model, by default. +# GRADIO_APP == "imageto3d", TRELLIS model. 
+os.environ["GRADIO_APP"] = "imageto3d_sam3d" from glob import glob import gradio as gr @@ -37,6 +39,16 @@ start_session, ) +app_name = os.getenv("GRADIO_APP") +if app_name == "imageto3d_sam3d": + enable_pre_resize = False + sample_step = 25 + bg_rm_model_name = "rembg" # "rembg", "rmbg14" +elif app_name == "imageto3d": + enable_pre_resize = True + sample_step = 12 + bg_rm_model_name = "rembg" # "rembg", "rmbg14" + with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: gr.HTML(image_css, visible=False) gr.HTML(lighting_css, visible=False) @@ -67,7 +79,7 @@ ) with gr.Row(): - with gr.Column(scale=2): + with gr.Column(scale=3): with gr.Tabs() as input_tabs: with gr.Tab( label="Image(auto seg)", id=0 @@ -142,7 +154,7 @@ ) rmbg_tag = gr.Radio( choices=["rembg", "rmbg14"], - value="rembg", + value=bg_rm_model_name, label="Background Removal Model", ) with gr.Row(): @@ -163,7 +175,11 @@ step=0.1, ) ss_sampling_steps = gr.Slider( - 1, 50, label="Sampling Steps", value=12, step=1 + 1, + 50, + label="Sampling Steps", + value=sample_step, + step=1, ) gr.Markdown("Visual Appearance Generation") with gr.Row(): @@ -175,7 +191,11 @@ step=0.1, ) slat_sampling_steps = gr.Slider( - 1, 50, label="Sampling Steps", value=12, step=1 + 1, + 50, + label="Sampling Steps", + value=sample_step, + step=1, ) generate_btn = gr.Button( @@ -242,7 +262,7 @@ has quality inspection, open with an editor to view details. """ ) - + enable_pre_resize = gr.State(enable_pre_resize) with gr.Row() as single_image_example: examples = gr.Examples( label="Image Gallery", @@ -252,7 +272,7 @@ glob("apps/assets/example_image/*") ) ], - inputs=[image_prompt, rmbg_tag], + inputs=[image_prompt, rmbg_tag, enable_pre_resize], fn=preprocess_image_fn, outputs=[image_prompt, raw_image_cache], run_on_click=True, @@ -274,16 +294,16 @@ run_on_click=True, examples_per_page=10, ) - with gr.Column(scale=1): + with gr.Column(scale=2): gr.Markdown("
") video_output = gr.Video( label="Generated 3D Asset", autoplay=True, loop=True, - height=300, + height=400, ) model_output_gs = gr.Model3D( - label="Gaussian Representation", height=300, interactive=False + label="Gaussian Representation", height=350, interactive=False ) aligned_gs = gr.Textbox(visible=False) gr.Markdown( @@ -292,9 +312,9 @@ with gr.Row(): model_output_mesh = gr.Model3D( label="Mesh Representation", - height=300, + height=350, interactive=False, - clear_color=[0.8, 0.8, 0.8, 1], + clear_color=[0, 0, 0, 1], elem_id="lighter_mesh", ) @@ -320,7 +340,7 @@ image_prompt.upload( preprocess_image_fn, - inputs=[image_prompt, rmbg_tag], + inputs=[image_prompt, rmbg_tag, enable_pre_resize], outputs=[image_prompt, raw_image_cache], ) image_prompt.change( @@ -437,11 +457,11 @@ inputs=[ image_prompt, seed, - ss_guidance_strength, ss_sampling_steps, - slat_guidance_strength, slat_sampling_steps, raw_image_cache, + ss_guidance_strength, + slat_guidance_strength, image_seg_sam, is_samimage, ], diff --git a/apps/text_to_3d.py b/apps/text_to_3d.py index 8c9012c..e5a176d 100644 --- a/apps/text_to_3d.py +++ b/apps/text_to_3d.py @@ -17,8 +17,9 @@ import os -os.environ["GRADIO_APP"] = "textto3d" - +# GRADIO_APP == "textto3d_sam3d", sam3d object model, by default. +# GRADIO_APP == "textto3d", TRELLIS model. +os.environ["GRADIO_APP"] = "textto3d_sam3d" import gradio as gr from app_style import custom_theme, image_css, lighting_css @@ -37,6 +38,14 @@ text2image_fn, ) +app_name = os.getenv("GRADIO_APP") +if app_name == "textto3d_sam3d": + enable_pre_resize = False + sample_step = 25 +elif app_name == "textto3d": + enable_pre_resize = True + sample_step = 12 + with gr.Blocks(delete_cache=(43200, 43200), theme=custom_theme) as demo: gr.HTML(image_css, visible=False) gr.HTML(lighting_css, visible=False) @@ -101,11 +110,11 @@ ) rmbg_tag = gr.Radio( choices=["rembg", "rmbg14"], - value="rembg", + value="rmbg14", label="Background Removal Model", ) ip_adapt_scale = gr.Slider( - 0, 1, label="IP-adapter Scale", value=0.3, step=0.05 + 0, 1, label="IP-adapter Scale", value=0.7, step=0.05 ) img_guidance_scale = gr.Slider( 1, 30, label="Text Guidance Scale", value=12, step=0.2 @@ -162,7 +171,11 @@ step=0.1, ) ss_sampling_steps = gr.Slider( - 1, 50, label="Sampling Steps", value=12, step=1 + 1, + 50, + label="Sampling Steps", + value=sample_step, + step=1, ) gr.Markdown("Visual Appearance Generation") with gr.Row(): @@ -174,7 +187,11 @@ step=0.1, ) slat_sampling_steps = gr.Slider( - 1, 50, label="Sampling Steps", value=12, step=1 + 1, + 50, + label="Sampling Steps", + value=sample_step, + step=1, ) generate_btn = gr.Button( @@ -265,7 +282,7 @@ visible=False, ) gr.Markdown( - "Generated image may be poor quality due to auto seg." + "Generated image may be poor quality due to auto seg. " "Retry by adjusting text prompt, seed or switch seg model in `Image Gen Settings`." 
) with gr.Row(): @@ -285,7 +302,7 @@ model_output_mesh = gr.Model3D( label="Mesh Representation", - clear_color=[0.8, 0.8, 0.8, 1], + clear_color=[0, 0, 0, 1], height=300, interactive=False, elem_id="lighter_mesh", @@ -323,6 +340,7 @@ ) output_buf = gr.State() + enable_pre_resize = gr.State(enable_pre_resize) demo.load(start_session) demo.unload(end_session) @@ -389,6 +407,7 @@ img_resolution, rmbg_tag, seed, + enable_pre_resize, ], outputs=[ image_sample1, @@ -420,11 +439,11 @@ inputs=[ select_img, seed, - ss_guidance_strength, ss_sampling_steps, - slat_guidance_strength, slat_sampling_steps, raw_image_cache, + ss_guidance_strength, + slat_guidance_strength, ], outputs=[output_buf, video_output], ).success( diff --git a/apps/texture_edit.py b/apps/texture_edit.py index 722ce6c..01afe13 100644 --- a/apps/texture_edit.py +++ b/apps/texture_edit.py @@ -267,7 +267,7 @@ def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox): demo.load(start_session) demo.unload(end_session) - + no_mesh_post_process = gr.State(True) mesh_input.change( lambda: tuple( [ @@ -368,6 +368,7 @@ def active_btn_by_content(mesh_content: gr.Model3D, text_content: gr.Textbox): texture_size, project_delight, fix_mesh, + no_mesh_post_process, ], outputs=[mesh_output, mesh_outpath, download_btn], ).success( diff --git a/apps/visualize_asset.py b/apps/visualize_asset.py index 5e9b94b..233df4a 100644 --- a/apps/visualize_asset.py +++ b/apps/visualize_asset.py @@ -27,7 +27,6 @@ import uuid import xml.etree.ElementTree as ET from pathlib import Path -from typing import Any, Dict, Tuple import gradio as gr import pandas as pd @@ -255,8 +254,7 @@ def search_assets(query: str, top_k: int): return items, gr.update(interactive=True), top_assets -# --- Mesh extraction --- -def _extract_mesh_paths(row) -> Tuple[str | None, str | None, str]: +def _extract_mesh_paths(row) -> tuple[str | None, str | None, str]: desc = row["description"] urdf_path = os.path.join(DATA_ROOT, row["urdf_path"]) asset_dir = os.path.join(DATA_ROOT, row["asset_dir"]) diff --git a/docs/acknowledgement.md b/docs/acknowledgement.md index d588194..b69fdae 100644 --- a/docs/acknowledgement.md +++ b/docs/acknowledgement.md @@ -1,7 +1,7 @@ # 🙌 Acknowledgement EmbodiedGen builds upon the following amazing projects and models: -🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 
[ManiSkill](https://github.com/haosulab/ManiSkill) +🌟 [Trellis](https://github.com/microsoft/TRELLIS) | 🌟 [Hunyuan-Delight](https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0) | 🌟 [Segment Anything](https://github.com/facebookresearch/segment-anything) | 🌟 [Rembg](https://github.com/danielgatis/rembg) | 🌟 [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4) | 🌟 [Stable Diffusion x4](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler) | 🌟 [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) | 🌟 [Kolors](https://github.com/Kwai-Kolors/Kolors) | 🌟 [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 🌟 [Aesthetic Score](http://captions.christoph-schuhmann.de/aesthetic_viz_laion_sac+logos+ava1-l14-linearMSE-en-2.37B.html) | 🌟 [Pano2Room](https://github.com/TrickyGo/Pano2Room) | 🌟 [Diffusion360](https://github.com/ArcherFMY/SD-T2I-360PanoImage) | 🌟 [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) | 🌟 [diffusers](https://github.com/huggingface/diffusers) | 🌟 [gsplat](https://github.com/nerfstudio-project/gsplat) | 🌟 [QWEN-2.5VL](https://github.com/QwenLM/Qwen2.5-VL) | 🌟 [GPT4o](https://platform.openai.com/docs/models/gpt-4o) | 🌟 [SD3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | 🌟 [ManiSkill](https://github.com/haosulab/ManiSkill) | 🌟 [SAM3D](https://github.com/facebookresearch/sam-3d-objects) --- diff --git a/docs/install.md b/docs/install.md index cf01f06..e3c3534 100644 --- a/docs/install.md +++ b/docs/install.md @@ -7,11 +7,12 @@ hide: ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.6 +git checkout v0.1.7 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen -bash install.sh basic +bash install.sh basic # around 20 mins +# Optional: `bash install.sh extra` for scene3d-cli ``` Please `huggingface-cli login` to ensure that the ckpts can be downloaded automatically afterwards. diff --git a/docs/services/image_to_3d.md b/docs/services/image_to_3d.md index 4e1a1c8..674d243 100644 --- a/docs/services/image_to_3d.md +++ b/docs/services/image_to_3d.md @@ -67,6 +67,8 @@ python apps/image_to_3d.py CUDA_VISIBLE_DEVICES=0 nohup python apps/image_to_3d.py > /dev/null 2>&1 & ``` +Support the use of [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) as 3D generation model, modify `GRADIO_APP` in `apps/image_to_3d.py` to switch model. + --- !!! tip "Getting Started" diff --git a/docs/tutorials/image_to_3d.md b/docs/tutorials/image_to_3d.md index 2f7178a..705ed35 100644 --- a/docs/tutorials/image_to_3d.md +++ b/docs/tutorials/image_to_3d.md @@ -5,10 +5,11 @@ Generate **physically plausible 3D assets** from a single input image, supportin --- ## ⚡ Command-Line Usage +Support the use of [SAM3D](https://github.com/facebookresearch/sam-3d-objects) or [TRELLIS](https://github.com/microsoft/TRELLIS) as 3D generation model, modify `IMAGE3D_MODEL` in `embodied_gen/scripts/imageto3d.py` to switch model. 
```bash img3d-cli --image_path apps/assets/example_image/sample_00.jpg \ -apps/assets/example_image/sample_01.jpg apps/assets/example_image/sample_19.jpg \ +apps/assets/example_image/sample_01.jpg \ --n_retry 1 --output_root outputs/imageto3d ``` diff --git a/docs/tutorials/layout_gen.md b/docs/tutorials/layout_gen.md index 3109ff1..6eeb642 100644 --- a/docs/tutorials/layout_gen.md +++ b/docs/tutorials/layout_gen.md @@ -60,6 +60,7 @@ You can also run multiple tasks via a task list file in the backend. CUDA_VISIBLE_DEVICES=0 nohup layout-cli \ --task_descs "apps/assets/example_layout/task_list.txt" \ --bg_list "outputs/bg_scenes/scene_list.txt" \ + --n_image_retry 4 --n_asset_retry 3 --n_pipe_retry 2 \ --output_root "outputs/layouts_gens" \ --insert_robot > layouts_gens.log & ``` diff --git a/docs/tutorials/text_to_3d.md b/docs/tutorials/text_to_3d.md index 3a81366..0c4b0dc 100644 --- a/docs/tutorials/text_to_3d.md +++ b/docs/tutorials/text_to_3d.md @@ -74,8 +74,8 @@ You will get the following results: Kolors Model CLI (Supports Chinese & English Prompts): ```bash bash embodied_gen/scripts/textto3d.sh \ - --prompts "small bronze figurine of a lion" "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ - --output_root outputs/textto3d_k + --prompts "A globe with wooden base and latitude and longitude lines" "橙色电动手钻,有磨损细节" \ + --output_root outputs/textto3d_k ``` > Models with more permissive licenses can be found in `embodied_gen/models/image_comm_model.py`. diff --git a/embodied_gen/data/asset_converter.py b/embodied_gen/data/asset_converter.py index 71ef27e..7b09b70 100644 --- a/embodied_gen/data/asset_converter.py +++ b/embodied_gen/data/asset_converter.py @@ -1,3 +1,20 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + from __future__ import annotations import logging diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py index 5908013..92c0a76 100644 --- a/embodied_gen/data/backproject_v2.py +++ b/embodied_gen/data/backproject_v2.py @@ -274,6 +274,7 @@ class TextureBacker: mask_thresh (float, optional): Threshold for visibility masks. smooth_texture (bool, optional): Apply post-processing to texture. inpaint_smooth (bool, optional): Apply inpainting smoothing. + mesh_post_process (bool, optional): False for preventing modification of vertices. 
Example: ```py @@ -308,6 +309,7 @@ def __init__( mask_thresh: float = 0.5, smooth_texture: bool = True, inpaint_smooth: bool = False, + mesh_post_process: bool = True, ) -> None: self.camera_params = camera_params self.renderer = None @@ -318,6 +320,7 @@ def __init__( self.mask_thresh = mask_thresh self.smooth_texture = smooth_texture self.inpaint_smooth = inpaint_smooth + self.mesh_post_process = mesh_post_process self.bake_angle_thresh = bake_angle_thresh self.bake_unreliable_kernel_size = int( @@ -668,7 +671,12 @@ def __call__( mesh, self.scale, self.center ) textured_mesh = save_mesh_with_mtl( - vertices, faces, uv_map, texture_np, output_path + vertices, + faces, + uv_map, + texture_np, + output_path, + mesh_process=self.mesh_post_process, ) return textured_mesh @@ -766,6 +774,7 @@ def parse_args(): help="Disable saving delight image", ) parser.add_argument("--n_max_faces", type=int, default=30000) + parser.add_argument("--no_mesh_post_process", action="store_true") args, unknown = parser.parse_known_args() return args @@ -856,6 +865,7 @@ def entrypoint( render_wh=args.resolution_hw, texture_wh=args.texture_wh, smooth_texture=not args.no_smooth_texture, + mesh_post_process=not args.no_mesh_post_process, ) textured_mesh = texture_backer(multiviews, mesh, args.output_path) diff --git a/embodied_gen/data/backproject_v3.py b/embodied_gen/data/backproject_v3.py index b22b497..81cea59 100644 --- a/embodied_gen/data/backproject_v3.py +++ b/embodied_gen/data/backproject_v3.py @@ -353,8 +353,8 @@ def parse_args(): parser.add_argument( "--distance", type=float, - default=5, - help="Camera distance (default: 5)", + default=4.5, + help="Camera distance (default: 4.5)", ) parser.add_argument( "--resolution_hw", @@ -400,8 +400,8 @@ def parse_args(): parser.add_argument( "--mesh_sipmlify_ratio", type=float, - default=0.9, - help="Mesh simplification ratio (default: 0.9)", + default=0.85, + help="Mesh simplification ratio (default: 0.85)", ) parser.add_argument( "--delight", action="store_true", help="Use delighting model." @@ -500,7 +500,7 @@ def entrypoint( faces = mesh.faces.astype(np.int32) vertices = vertices.astype(np.float32) - if not args.skip_fix_mesh and len(faces) > 10 * args.n_max_faces: + if not args.skip_fix_mesh: mesh_fixer = MeshFixer(vertices, faces, args.device) vertices, faces = mesh_fixer( filter_ratio=args.mesh_sipmlify_ratio, @@ -512,7 +512,7 @@ def entrypoint( if len(faces) > args.n_max_faces: mesh_fixer = MeshFixer(vertices, faces, args.device) vertices, faces = mesh_fixer( - filter_ratio=max(0.05, args.mesh_sipmlify_ratio - 0.2), + filter_ratio=max(0.1, args.mesh_sipmlify_ratio - 0.1), max_hole_size=0.04, resolution=1024, num_views=1000, diff --git a/embodied_gen/data/utils.py b/embodied_gen/data/utils.py index fa2f7d5..74f96c6 100644 --- a/embodied_gen/data/utils.py +++ b/embodied_gen/data/utils.py @@ -15,10 +15,13 @@ # permissions and limitations under the License. 
+import logging import math import os -import random +import time import zipfile +from contextlib import contextmanager +from dataclasses import dataclass, field from shutil import rmtree from typing import List, Tuple, Union @@ -28,20 +31,9 @@ import nvdiffrast.torch as dr import torch import torch.nn.functional as F -from PIL import Image, ImageEnhance - -try: - from kolors.models.modeling_chatglm import ChatGLMModel - from kolors.models.tokenization_chatglm import ChatGLMTokenizer -except ImportError: - ChatGLMTokenizer = None - ChatGLMModel = None -import logging -from dataclasses import dataclass, field - import trimesh from kaolin.render.camera import Camera -from torch import nn +from PIL import Image, ImageEnhance logger = logging.getLogger(__name__) @@ -50,10 +42,8 @@ "DiffrastRender", "save_images", "render_pbr", - "prelabel_text_feature", "calc_vertex_normals", "normalize_vertices_array", - "load_mesh_to_unit_cube", "as_list", "CameraSetting", "import_kaolin_mesh", @@ -67,6 +57,7 @@ "trellis_preprocess", "delete_dir", "kaolin_to_opencv_view", + "model_device_ctx", ] @@ -520,114 +511,6 @@ def render_pbr( return image, albedo, diffuse, normal -def _move_to_target_device(data, device: str): - if isinstance(data, dict): - for key, value in data.items(): - data[key] = _move_to_target_device(value, device) - elif isinstance(data, torch.Tensor): - return data.to(device) - - return data - - -def _encode_prompt( - prompt_batch, - text_encoders, - tokenizers, - proportion_empty_prompts=0, - is_train=True, -): - prompt_embeds_list = [] - - captions = [] - for caption in prompt_batch: - if random.random() < proportion_empty_prompts: - captions.append("") - elif isinstance(caption, str): - captions.append(caption) - elif isinstance(caption, (list, np.ndarray)): - captions.append(random.choice(caption) if is_train else caption[0]) - - with torch.no_grad(): - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - captions, - padding="max_length", - max_length=256, - truncation=True, - return_tensors="pt", - ).to(text_encoder.device) - - output = text_encoder( - input_ids=text_inputs.input_ids, - attention_mask=text_inputs.attention_mask, - position_ids=text_inputs.position_ids, - output_hidden_states=True, - ) - - # We are only interested in the pooled output of the text encoder. 
- prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone() - pooled_prompt_embeds = output.hidden_states[-1][-1, :, :].clone() - bs_embed, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) - - return prompt_embeds, pooled_prompt_embeds - - -def load_llm_models(pretrained_model_name_or_path: str, device: str): - tokenizer = ChatGLMTokenizer.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - ) - text_encoder = ChatGLMModel.from_pretrained( - pretrained_model_name_or_path, - subfolder="text_encoder", - ).to(device) - - text_encoders = [ - text_encoder, - ] - tokenizers = [ - tokenizer, - ] - - logger.info(f"Load model from {pretrained_model_name_or_path} done.") - - return tokenizers, text_encoders - - -def prelabel_text_feature( - prompt_batch: List[str], - output_dir: str, - tokenizers: nn.Module, - text_encoders: nn.Module, -) -> List[str]: - os.makedirs(output_dir, exist_ok=True) - - # prompt_batch ["text..."] - prompt_embeds, pooled_prompt_embeds = _encode_prompt( - prompt_batch, text_encoders, tokenizers - ) - - prompt_embeds = _move_to_target_device(prompt_embeds, device="cpu") - pooled_prompt_embeds = _move_to_target_device( - pooled_prompt_embeds, device="cpu" - ) - - data_dict = dict( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds - ) - - save_path = os.path.join(output_dir, "text_feat.pth") - torch.save(data_dict, save_path) - - return save_path - - def _calc_face_normals( vertices: torch.Tensor, # V,3 first vertex may be unreferenced faces: torch.Tensor, # F,3 long, first face may be all zero @@ -683,25 +566,6 @@ def normalize_vertices_array( return vertices, scale, center -def load_mesh_to_unit_cube( - mesh_file: str, - mesh_scale: float = 1.0, -) -> tuple[trimesh.Trimesh, float, list[float]]: - if not os.path.exists(mesh_file): - raise FileNotFoundError(f"mesh_file path {mesh_file} not exists.") - - mesh = trimesh.load(mesh_file) - if isinstance(mesh, trimesh.Scene): - mesh = trimesh.utils.concatenate(mesh) - - vertices, scale, center = normalize_vertices_array( - mesh.vertices, mesh_scale - ) - mesh.vertices = vertices - - return mesh, scale, center - - def as_list(obj): if isinstance(obj, (list, tuple)): return obj @@ -862,6 +726,7 @@ def save_mesh_with_mtl( texture: Union[Image.Image, np.ndarray], output_path: str, material_base=(250, 250, 250, 255), + mesh_process: bool = True, ) -> trimesh.Trimesh: if isinstance(texture, np.ndarray): texture = Image.fromarray(texture) @@ -870,6 +735,7 @@ def save_mesh_with_mtl( vertices, faces, visual=trimesh.visual.TextureVisuals(uv=uvs, image=texture), + process=mesh_process, # True for preventing modification of vertices ) mesh.visual.material = trimesh.visual.material.SimpleMaterial( image=texture, @@ -998,8 +864,9 @@ def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor: def resize_pil(image: Image.Image, max_size: int = 1024) -> Image.Image: - max_size = max(image.size) - scale = min(1, 1024 / max_size) + current_max_dim = max(image.size) + scale = min(1, max_size / current_max_dim) + if scale < 1: new_size = (int(image.width * scale), int(image.height * scale)) image = image.resize(new_size, Image.Resampling.LANCZOS) @@ -1068,3 +935,34 @@ def delete_dir(folder_path: str, keep_subs: list[str] = None) -> None: rmtree(item_path) else: os.remove(item_path) + 
+ +@contextmanager +def model_device_ctx( + *models, + src_device: str = "cpu", + dst_device: str = "cuda", + verbose: bool = False, +): + start = time.perf_counter() + for m in models: + if m is None: + continue + m.to(dst_device) + to_cuda_time = time.perf_counter() - start + + try: + yield + finally: + start = time.perf_counter() + for m in models: + if m is None: + continue + m.to(src_device) + to_cpu_time = time.perf_counter() - start + + if verbose: + model_names = [m.__class__.__name__ for m in models] + logger.debug( + f"[model_device_ctx] {model_names} to cuda: {to_cuda_time:.1f}s, to cpu: {to_cpu_time:.1f}s" + ) diff --git a/embodied_gen/models/sam3d.py b/embodied_gen/models/sam3d.py new file mode 100644 index 0000000..4b28e40 --- /dev/null +++ b/embodied_gen/models/sam3d.py @@ -0,0 +1,152 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +from embodied_gen.utils.monkey_patches import monkey_patch_sam3d + +monkey_patch_sam3d() +import os +import sys + +import numpy as np +from hydra.utils import instantiate +from modelscope import snapshot_download +from omegaconf import OmegaConf +from PIL import Image + +current_file_path = os.path.abspath(__file__) +current_dir = os.path.dirname(current_file_path) +sys.path.append(os.path.join(current_dir, "../..")) +from loguru import logger +from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import ( + InferencePipelinePointMap, +) + +logger.remove() +logger.add(lambda _: None, level="ERROR") + + +__all__ = ["Sam3dInference"] + + +class Sam3dInference: + """Wrapper for the SAM-3D-Objects inference pipeline. + + This class handles loading the SAM-3D-Objects model, configuring it for inference, + and running the pipeline on input images (optionally with masks and pointmaps). + It supports distillation options and inference step customization. + + Args: + local_dir (str): Directory to store or load model weights and configs. + compile (bool): Whether to compile the model for faster inference. + + Methods: + merge_mask_to_rgba(image, mask): + Merges a binary mask into the alpha channel of an RGB image. + + run(image, mask=None, seed=None, pointmap=None, use_stage1_distillation=False, + use_stage2_distillation=False, stage1_inference_steps=25, stage2_inference_steps=25): + Runs the inference pipeline and returns the output dictionary. + """ + + def __init__( + self, local_dir: str = "weights/sam-3d-objects", compile: bool = False + ) -> None: + if not os.path.exists(local_dir): + snapshot_download("facebook/sam-3d-objects", local_dir=local_dir) + config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml") + config = OmegaConf.load(config_file) + config.rendering_engine = "nvdiffrast" + config.compile_model = compile + config.workspace_dir = os.path.dirname(config_file) + # Generate 4 instead of 32 gs in each pixel for efficient storage. 
+ config["slat_decoder_gs_config_path"] = config.pop( + "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml" + ) + config["slat_decoder_gs_ckpt_path"] = config.pop( + "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt" + ) + self.pipeline: InferencePipelinePointMap = instantiate(config) + + def merge_mask_to_rgba( + self, image: np.ndarray, mask: np.ndarray + ) -> np.ndarray: + mask = mask.astype(np.uint8) * 255 + mask = mask[..., None] + rgba_image = np.concatenate([image[..., :3], mask], axis=-1) + + return rgba_image + + def run( + self, + image: np.ndarray | Image.Image, + mask: np.ndarray = None, + seed: int = None, + pointmap: np.ndarray = None, + use_stage1_distillation: bool = False, + use_stage2_distillation: bool = False, + stage1_inference_steps: int = 25, + stage2_inference_steps: int = 25, + ) -> dict: + if isinstance(image, Image.Image): + image = np.array(image) + if mask is not None: + image = self.merge_mask_to_rgba(image, mask) + return self.pipeline.run( + image, + None, + seed, + stage1_only=False, + with_mesh_postprocess=False, + with_texture_baking=False, + with_layout_postprocess=False, + use_vertex_color=True, + use_stage1_distillation=use_stage1_distillation, + use_stage2_distillation=use_stage2_distillation, + stage1_inference_steps=stage1_inference_steps, + stage2_inference_steps=stage2_inference_steps, + pointmap=pointmap, + ) + + +if __name__ == "__main__": + pipeline = Sam3dInference() + + from time import time + + import torch + from embodied_gen.models.segment_model import RembgRemover + + input_image = "apps/assets/example_image/sample_00.jpg" + output_gs = "outputs/splat.ply" + remover = RembgRemover() + clean_image = remover(input_image) + + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + torch.cuda.empty_cache() + + start = time() + output = pipeline.run(clean_image, seed=42) + print(f"Running cost: {round(time()-start, 1)}") + + if torch.cuda.is_available(): + max_memory = torch.cuda.max_memory_allocated() / (1024**3) + print(f"(Max VRAM): {max_memory:.2f} GB") + + print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") + + output["gs"].save_ply(output_gs) + print(f"Saved to {output_gs}") diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py index 6f54cfa..a82afa6 100644 --- a/embodied_gen/models/segment_model.py +++ b/embodied_gen/models/segment_model.py @@ -43,6 +43,7 @@ "SAMRemover", "SAMPredictor", "RembgRemover", + "BMGG14Remover", "get_segmented_image_by_agent", ] @@ -376,7 +377,7 @@ def __init__(self) -> None: def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None - ): + ) -> Image.Image: """Removes background from an image. 
Args: @@ -496,13 +497,18 @@ def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool: # input_image = "outputs/text2image/tmp/bucket.jpeg" # output_image = "outputs/text2image/tmp/bucket_seg.png" - remover = SAMRemover(model_type="vit_h") - remover = RembgRemover() - clean_image = remover(input_image) - clean_image.save(output_image) - get_segmented_image_by_agent( - Image.open(input_image), remover, remover, None, "./test_seg.png" - ) + # remover = SAMRemover(model_type="vit_h") + # remover = RembgRemover() + # clean_image = remover(input_image) + # clean_image.save(output_image) + # get_segmented_image_by_agent( + # Image.open(input_image), remover, remover, None, "./test_seg.png" + # ) remover = BMGG14Remover() - remover("embodied_gen/models/test_seg.jpg", "./seg.png") + clean_image = remover("./camera.jpeg", "./seg.png") + from embodied_gen.utils.process_media import ( + keep_largest_connected_component, + ) + + keep_largest_connected_component(clean_image).save("./seg_post.png") diff --git a/embodied_gen/scripts/gen_scene3d.py b/embodied_gen/scripts/gen_scene3d.py index 42d2527..c1ab5ff 100644 --- a/embodied_gen/scripts/gen_scene3d.py +++ b/embodied_gen/scripts/gen_scene3d.py @@ -1,3 +1,20 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + import logging import os import random diff --git a/embodied_gen/scripts/gen_texture.py b/embodied_gen/scripts/gen_texture.py index a0023a8..d28336f 100644 --- a/embodied_gen/scripts/gen_texture.py +++ b/embodied_gen/scripts/gen_texture.py @@ -1,3 +1,20 @@ +# Project EmbodiedGen +# +# Copyright (c) 2025 Horizon Robotics. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + import os import shutil from dataclasses import dataclass @@ -94,6 +111,7 @@ def entrypoint() -> None: delight=cfg.delight, no_save_delight_img=True, texture_wh=[cfg.texture_size, cfg.texture_size], + no_mesh_post_process=True, ) drender_api( mesh_path=f"{output_root}/texture_mesh/{uuid}.obj", diff --git a/embodied_gen/scripts/imageto3d.py b/embodied_gen/scripts/imageto3d.py index 61a14b6..dfbc8e4 100644 --- a/embodied_gen/scripts/imageto3d.py +++ b/embodied_gen/scripts/imageto3d.py @@ -14,30 +14,30 @@ # implied. See the License for the specific language governing # permissions and limitations under the License. 
- import argparse import os import random -import sys from glob import glob from shutil import copy, copytree, rmtree import numpy as np -import torch import trimesh from PIL import Image from embodied_gen.data.backproject_v3 import entrypoint as backproject_api -from embodied_gen.data.utils import delete_dir, trellis_preprocess +from embodied_gen.data.utils import delete_dir +# from embodied_gen.models.sr_model import ImageRealESRGAN # from embodied_gen.models.delight_model import DelightingModel from embodied_gen.models.gs_model import GaussianOperator from embodied_gen.models.segment_model import RembgRemover - -# from embodied_gen.models.sr_model import ImageRealESRGAN from embodied_gen.scripts.render_gs import entrypoint as render_gs_api from embodied_gen.utils.gpt_clients import GPT_CLIENT +from embodied_gen.utils.inference import image3d_model_infer from embodied_gen.utils.log import logger -from embodied_gen.utils.process_media import merge_images_video +from embodied_gen.utils.process_media import ( + combine_images_to_grid, + merge_images_video, +) from embodied_gen.utils.tags import VERSION from embodied_gen.utils.trender import render_video from embodied_gen.validators.quality_checkers import ( @@ -48,26 +48,24 @@ ) from embodied_gen.validators.urdf_convertor import URDFGenerator -current_file_path = os.path.abspath(__file__) -current_dir = os.path.dirname(current_file_path) -sys.path.append(os.path.join(current_dir, "../..")) -from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline +# random.seed(0) +IMAGE3D_MODEL = "SAM3D" # TRELLIS or SAM3D +logger.info(f"Loading {IMAGE3D_MODEL} as Image3D Models...") +if IMAGE3D_MODEL == "TRELLIS": + from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline -os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( - "~/.cache/torch_extensions" -) -os.environ["GRADIO_ANALYTICS_ENABLED"] = "false" -os.environ["SPCONV_ALGO"] = "native" -random.seed(0) + PIPELINE = TrellisImageTo3DPipeline.from_pretrained( + "microsoft/TRELLIS-image-large" + ) + # PIPELINE.cuda() +elif IMAGE3D_MODEL == "SAM3D": + from embodied_gen.models.sam3d import Sam3dInference + + PIPELINE = Sam3dInference() -logger.info("Loading Image3D Models...") # DELIGHT = DelightingModel() # IMAGESR_MODEL = ImageRealESRGAN(outscale=4) RBG_REMOVER = RembgRemover() -PIPELINE = TrellisImageTo3DPipeline.from_pretrained( - "microsoft/TRELLIS-image-large" -) -# PIPELINE.cuda() SEG_CHECKER = ImageSegChecker(GPT_CLIENT) GEO_CHECKER = MeshGeoChecker(GPT_CLIENT) AESTHETIC_CHECKER = ImageAestheticChecker() @@ -151,7 +149,6 @@ def entrypoint(**kwargs): # Segmentation: Get segmented image using Rembg. 
seg_path = f"{output_root}/{filename}_cond.png" seg_image = RBG_REMOVER(image) if image.mode != "RGBA" else image - seg_image = trellis_preprocess(seg_image) seg_image.save(seg_path) seed = args.seed @@ -162,27 +159,8 @@ def entrypoint(**kwargs): logger.info( f"Try: {try_idx + 1}/{args.n_retry}, Seed: {seed}, Prompt: {seg_path}" ) - # Run the pipeline try: - PIPELINE.cuda() - outputs = PIPELINE.run( - seg_image, - preprocess_image=False, - seed=( - random.randint(0, 100000) if seed is None else seed - ), - # Optional parameters - # sparse_structure_sampler_params={ - # "steps": 12, - # "cfg_strength": 7.5, - # }, - # slat_sampler_params={ - # "steps": 12, - # "cfg_strength": 3, - # }, - ) - PIPELINE.cpu() - torch.cuda.empty_cache() + outputs = image3d_model_infer(PIPELINE, seg_image, seed) except Exception as e: logger.error( f"[Pipeline Failed] process {image_path}: {e}, skip." @@ -215,14 +193,13 @@ def entrypoint(**kwargs): render_gs_api( input_gs=aligned_gs_path, output_path=color_path, - elevation=[20, -10, 60, -50], - num_images=12, + elevation=[30, -30], + num_images=4, ) - color_img = Image.open(color_path) - keep_height = int(color_img.height * 2 / 3) - crop_img = color_img.crop((0, 0, color_img.width, keep_height)) - geo_flag, geo_result = GEO_CHECKER([crop_img], text=asset_node) + geo_flag, geo_result = GEO_CHECKER( + [color_img], text=asset_node + ) logger.warning( f"{GEO_CHECKER.__class__.__name__}: {geo_result} for {seg_path}" ) @@ -232,8 +209,8 @@ def entrypoint(**kwargs): seed = random.randint(0, 100000) if seed is not None else None # Render the video for generated 3D asset. - color_images = render_video(gs_model)["color"] - normal_images = render_video(mesh_model)["normal"] + color_images = render_video(gs_model, r=1.85)["color"] + normal_images = render_video(mesh_model, r=1.85)["normal"] video_path = os.path.join(output_root, "gs_mesh.mp4") merge_images_video(color_images, normal_images, video_path) @@ -312,7 +289,7 @@ def entrypoint(**kwargs): image_paths = glob(f"{image_dir}/*.png") images_list = [] for checker in CHECKERS: - images = image_paths + images = combine_images_to_grid(image_paths) if isinstance(checker, ImageSegChecker): images = [ f"{output_root}/{filename}_raw.png", @@ -334,9 +311,12 @@ def entrypoint(**kwargs): f"{result_dir}/{urdf_convertor.output_mesh_dir}", ) copy(video_path, f"{result_dir}/video.mp4") + if not args.keep_intermediate: delete_dir(output_root, keep_subs=["result"]) + logger.info(f"Saved results for {image_path} in {result_dir}") + except Exception as e: logger.error(f"Failed to process {image_path}: {e}, skip.") continue diff --git a/embodied_gen/scripts/render_gs.py b/embodied_gen/scripts/render_gs.py index 3a3d7a2..e00c548 100644 --- a/embodied_gen/scripts/render_gs.py +++ b/embodied_gen/scripts/render_gs.py @@ -27,7 +27,6 @@ from embodied_gen.data.utils import ( CameraSetting, init_kal_camera, - normalize_vertices_array, ) from embodied_gen.models.gs_model import load_gs_model from embodied_gen.utils.process_media import combine_images_to_grid diff --git a/embodied_gen/scripts/textto3d.py b/embodied_gen/scripts/textto3d.py index 4e96063..c5fe5f3 100644 --- a/embodied_gen/scripts/textto3d.py +++ b/embodied_gen/scripts/textto3d.py @@ -30,6 +30,7 @@ from embodied_gen.utils.log import logger from embodied_gen.utils.process_media import ( check_object_edge_truncated, + combine_images_to_grid, render_asset3d, ) from embodied_gen.validators.quality_checkers import ( @@ -51,7 +52,6 @@ __all__ = [ - "text_to_image", "text_to_3d", ] @@ -176,12 
+176,12 @@ def text_to_3d(**kwargs) -> dict:
     image_path = render_asset3d(
         mesh_path,
         output_root=f"{node_save_dir}/result",
-        num_images=6,
+        num_images=4,
         elevation=(30, -30),
         output_subdir="renders",
         no_index_file=True,
     )
-
+    image_path = combine_images_to_grid(image_path)
     check_text = asset_type if asset_type is not None else prompt
     qa_flag, qa_result = TXTGEN_CHECKER(check_text, image_path)
     logger.warning(
diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py
index 47f5ce2..32a9ea9 100644
--- a/embodied_gen/utils/gpt_clients.py
+++ b/embodied_gen/utils/gpt_clients.py
@@ -21,13 +21,14 @@
 from io import BytesIO
 from typing import Optional
 
+import openai
 import yaml
 from openai import AzureOpenAI, OpenAI  # pip install openai
 from PIL import Image
 from tenacity import (
     retry,
+    retry_if_not_exception_type,
     stop_after_attempt,
-    stop_after_delay,
     wait_random_exponential,
 )
 from embodied_gen.utils.process_media import combine_images_to_grid
@@ -106,8 +107,9 @@ def __init__(
         logger.info(f"Using GPT model: {self.model_name}.")
 
     @retry(
-        wait=wait_random_exponential(min=1, max=20),
-        stop=(stop_after_attempt(10) | stop_after_delay(30)),
+        retry=retry_if_not_exception_type(openai.BadRequestError),
+        wait=wait_random_exponential(min=1, max=10),
+        stop=stop_after_attempt(5),
     )
     def completion_with_backoff(self, **kwargs):
         """Performs a chat completion request with retry/backoff."""
@@ -246,3 +248,8 @@ def check_connection(self) -> None:
         model_name=model_name,
         check_connection=False,
     )
+
+
+if __name__ == "__main__":
+    response = GPT_CLIENT.query("What is the capital of China?")
+    print(f"Response: {response}")
diff --git a/embodied_gen/utils/inference.py b/embodied_gen/utils/inference.py
new file mode 100644
index 0000000..5e19e93
--- /dev/null
+++ b/embodied_gen/utils/inference.py
@@ -0,0 +1,59 @@
+from embodied_gen.utils.monkey_patches import monkey_path_trellis
+
+monkey_path_trellis()
+import random
+
+import torch
+from PIL import Image
+from embodied_gen.data.utils import trellis_preprocess
+from embodied_gen.models.sam3d import Sam3dInference
+from embodied_gen.utils.trender import pack_state, unpack_state
+from thirdparty.TRELLIS.trellis.pipelines import TrellisImageTo3DPipeline
+
+__all__ = [
+    "image3d_model_infer",
+]
+
+
+def image3d_model_infer(
+    pipe: TrellisImageTo3DPipeline | Sam3dInference,
+    seg_image: Image.Image,
+    seed: int = None,
+    **kwargs: dict,
+) -> dict[str, any]:
+    if isinstance(pipe, TrellisImageTo3DPipeline):
+        pipe.cuda()
+        seg_image = trellis_preprocess(seg_image)
+        outputs = pipe.run(
+            seg_image,
+            preprocess_image=False,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # Optional parameters
+            # sparse_structure_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 7.5,
+            # },
+            # slat_sampler_params={
+            #     "steps": 12,
+            #     "cfg_strength": 3,
+            # },
+            **kwargs,
+        )
+        pipe.cpu()
+    elif isinstance(pipe, Sam3dInference):
+        outputs = pipe.run(
+            seg_image,
+            seed=(random.randint(0, 100000) if seed is None else seed),
+            # stage1_inference_steps=25,
+            # stage2_inference_steps=25,
+            **kwargs,
+        )
+        state = pack_state(outputs["gaussian"][0], outputs["mesh"][0])
+        # Align GS3D from SAM3D with TRELLIS format.
+ outputs["gaussian"][0], _ = unpack_state(state, device="cuda") + else: + raise ValueError(f"Unsupported pipeline type: {type(pipe)}") + + torch.cuda.empty_cache() + + return outputs diff --git a/embodied_gen/utils/monkey_patches.py b/embodied_gen/utils/monkey_patches.py index b5d35cf..6e3b033 100644 --- a/embodied_gen/utils/monkey_patches.py +++ b/embodied_gen/utils/monkey_patches.py @@ -25,6 +25,73 @@ from PIL import Image from torchvision import transforms +__all__ = [ + "monkey_patch_pano2room", + "monkey_patch_maniskill", + "monkey_patch_sam3d", +] + + +def monkey_path_trellis(): + import torch.nn.functional as F + + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + sys.path.append(os.path.join(current_dir, "../..")) + + from thirdparty.TRELLIS.trellis.representations import Gaussian + from thirdparty.TRELLIS.trellis.representations.gaussian.general_utils import ( + build_scaling_rotation, + inverse_sigmoid, + strip_symmetric, + ) + + os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( + "~/.cache/torch_extensions" + ) + os.environ["SPCONV_ALGO"] = "auto" # Can be 'native' or 'auto' + os.environ['ATTN_BACKEND'] = ( + "xformers" # Can be 'flash-attn' or 'xformers' + ) + from thirdparty.TRELLIS.trellis.modules.sparse import set_attn + + set_attn("xformers") + + def patched_setup_functions(self): + def inverse_softplus(x): + return x + torch.log(-torch.expm1(-x)) + + def build_covariance_from_scaling_rotation( + scaling, scaling_modifier, rotation + ): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + if self.scaling_activation_type == "exp": + self.scaling_activation = torch.exp + self.inverse_scaling_activation = torch.log + elif self.scaling_activation_type == "softplus": + self.scaling_activation = F.softplus + self.inverse_scaling_activation = inverse_softplus + + self.covariance_activation = build_covariance_from_scaling_rotation + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + self.rotation_activation = F.normalize + + self.scale_bias = self.inverse_scaling_activation( + torch.tensor(self.scaling_bias) + ).to(self.device) + self.rots_bias = torch.zeros((4)).to(self.device) + self.rots_bias[0] = 1 + self.opacity_bias = self.inverse_opacity_activation( + torch.tensor(self.opacity_bias) + ).to(self.device) + + Gaussian.setup_functions = patched_setup_functions + def monkey_patch_pano2room(): current_file_path = os.path.abspath(__file__) @@ -216,3 +283,363 @@ def get_rgba_tensor(camera, return_alpha): ManiSkillScene.get_human_render_camera_images = ( get_human_render_camera_images ) + + +def monkey_patch_sam3d(): + from typing import Optional, Union + + from embodied_gen.data.utils import model_device_ctx + from embodied_gen.utils.log import logger + + os.environ["LIDRA_SKIP_INIT"] = "true" + + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + sam3d_root = os.path.abspath( + os.path.join(current_dir, "../../thirdparty/sam3d") + ) + if sam3d_root not in sys.path: + sys.path.insert(0, sam3d_root) + + def patch_pointmap_infer_pipeline(): + from copy import deepcopy + + try: + from sam3d_objects.pipeline.inference_pipeline_pointmap import ( + InferencePipelinePointMap, + ) + except ImportError: + logger.error( + "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths." 
+ ) + return + + def patch_run( + self, + image: Union[None, Image.Image, np.ndarray], + mask: Union[None, Image.Image, np.ndarray] = None, + seed: Optional[int] = None, + stage1_only=False, + with_mesh_postprocess=True, + with_texture_baking=True, + with_layout_postprocess=True, + use_vertex_color=False, + stage1_inference_steps=None, + stage2_inference_steps=None, + use_stage1_distillation=False, + use_stage2_distillation=False, + pointmap=None, + decode_formats=None, + estimate_plane=False, + ) -> dict: + image = self.merge_image_and_mask(image, mask) + with self.device: + pointmap_dict = self.compute_pointmap(image, pointmap) + pointmap = pointmap_dict["pointmap"] + pts = type(self)._down_sample_img(pointmap) + pts_colors = type(self)._down_sample_img( + pointmap_dict["pts_color"] + ) + + if estimate_plane: + return self.estimate_plane(pointmap_dict, image) + + ss_input_dict = self.preprocess_image( + image, self.ss_preprocessor, pointmap=pointmap + ) + + slat_input_dict = self.preprocess_image( + image, self.slat_preprocessor + ) + if seed is not None: + torch.manual_seed(seed) + + with model_device_ctx( + self.models["ss_generator"], + self.models["ss_decoder"], + self.condition_embedders["ss_condition_embedder"], + ): + ss_return_dict = self.sample_sparse_structure( + ss_input_dict, + inference_steps=stage1_inference_steps, + use_distillation=use_stage1_distillation, + ) + + # We could probably use the decoder from the models themselves + pointmap_scale = ss_input_dict.get("pointmap_scale", None) + pointmap_shift = ss_input_dict.get("pointmap_shift", None) + ss_return_dict.update( + self.pose_decoder( + ss_return_dict, + scene_scale=pointmap_scale, + scene_shift=pointmap_shift, + ) + ) + + ss_return_dict["scale"] = ( + ss_return_dict["scale"] + * ss_return_dict["downsample_factor"] + ) + + if stage1_only: + logger.info("Finished!") + ss_return_dict["voxel"] = ( + ss_return_dict["coords"][:, 1:] / 64 - 0.5 + ) + return { + **ss_return_dict, + "pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3 + "pointmap_colors": pts_colors.cpu().permute( + (1, 2, 0) + ), # HxWx3 + } + # return ss_return_dict + + coords = ss_return_dict["coords"] + with model_device_ctx( + self.models["slat_generator"], + self.condition_embedders["slat_condition_embedder"], + ): + slat = self.sample_slat( + slat_input_dict, + coords, + inference_steps=stage2_inference_steps, + use_distillation=use_stage2_distillation, + ) + + with model_device_ctx( + self.models["slat_decoder_mesh"], + self.models["slat_decoder_gs"], + self.models["slat_decoder_gs_4"], + ): + outputs = self.decode_slat( + slat, + ( + self.decode_formats + if decode_formats is None + else decode_formats + ), + ) + + outputs = self.postprocess_slat_output( + outputs, + with_mesh_postprocess, + with_texture_baking, + use_vertex_color, + ) + glb = outputs.get("glb", None) + + try: + if ( + with_layout_postprocess + and self.layout_post_optimization_method is not None + ): + assert ( + glb is not None + ), "require mesh to run postprocessing" + logger.info( + "Running layout post optimization method..." 
+ ) + postprocessed_pose = self.run_post_optimization( + deepcopy(glb), + pointmap_dict["intrinsics"], + ss_return_dict, + ss_input_dict, + ) + ss_return_dict.update(postprocessed_pose) + except Exception as e: + logger.error( + f"Error during layout post optimization: {e}", + exc_info=True, + ) + + result = { + **ss_return_dict, + **outputs, + "pointmap": pts.cpu().permute((1, 2, 0)), + "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)), + } + return result + + InferencePipelinePointMap.run = patch_run + + def patch_infer_init(): + import torch + + try: + from sam3d_objects.pipeline import preprocess_utils + from sam3d_objects.pipeline.inference_pipeline_pointmap import ( + InferencePipeline, + ) + from sam3d_objects.pipeline.inference_utils import ( + SLAT_MEAN, + SLAT_STD, + ) + except ImportError: + print( + "[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline." + ) + return + + def patch_init( + self, + ss_generator_config_path, + ss_generator_ckpt_path, + slat_generator_config_path, + slat_generator_ckpt_path, + ss_decoder_config_path, + ss_decoder_ckpt_path, + slat_decoder_gs_config_path, + slat_decoder_gs_ckpt_path, + slat_decoder_mesh_config_path, + slat_decoder_mesh_ckpt_path, + slat_decoder_gs_4_config_path=None, + slat_decoder_gs_4_ckpt_path=None, + ss_encoder_config_path=None, + ss_encoder_ckpt_path=None, + decode_formats=["gaussian", "mesh"], + dtype="bfloat16", + pad_size=1.0, + version="v0", + device="cuda", + ss_preprocessor=preprocess_utils.get_default_preprocessor(), + slat_preprocessor=preprocess_utils.get_default_preprocessor(), + ss_condition_input_mapping=["image"], + slat_condition_input_mapping=["image"], + pose_decoder_name="default", + workspace_dir="", + downsample_ss_dist=0, # the distance we use to downsample + ss_inference_steps=25, + ss_rescale_t=3, + ss_cfg_strength=7, + ss_cfg_interval=[0, 500], + ss_cfg_strength_pm=0.0, + slat_inference_steps=25, + slat_rescale_t=3, + slat_cfg_strength=5, + slat_cfg_interval=[0, 500], + rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d, + shape_model_dtype=None, + compile_model=False, + slat_mean=SLAT_MEAN, + slat_std=SLAT_STD, + ): + self.rendering_engine = rendering_engine + self.device = torch.device(device) + self.compile_model = compile_model + with self.device: + self.decode_formats = decode_formats + self.pad_size = pad_size + self.version = version + self.ss_condition_input_mapping = ss_condition_input_mapping + self.slat_condition_input_mapping = ( + slat_condition_input_mapping + ) + self.workspace_dir = workspace_dir + self.downsample_ss_dist = downsample_ss_dist + self.ss_inference_steps = ss_inference_steps + self.ss_rescale_t = ss_rescale_t + self.ss_cfg_strength = ss_cfg_strength + self.ss_cfg_interval = ss_cfg_interval + self.ss_cfg_strength_pm = ss_cfg_strength_pm + self.slat_inference_steps = slat_inference_steps + self.slat_rescale_t = slat_rescale_t + self.slat_cfg_strength = slat_cfg_strength + self.slat_cfg_interval = slat_cfg_interval + + self.dtype = self._get_dtype(dtype) + if shape_model_dtype is None: + self.shape_model_dtype = self.dtype + else: + self.shape_model_dtype = self._get_dtype(shape_model_dtype) + + # Setup preprocessors + self.pose_decoder = self.init_pose_decoder( + ss_generator_config_path, pose_decoder_name + ) + self.ss_preprocessor = self.init_ss_preprocessor( + ss_preprocessor, ss_generator_config_path + ) + self.slat_preprocessor = slat_preprocessor + + raw_device = self.device + self.device = torch.device("cpu") + ss_generator = 
self.init_ss_generator( + ss_generator_config_path, ss_generator_ckpt_path + ) + slat_generator = self.init_slat_generator( + slat_generator_config_path, slat_generator_ckpt_path + ) + ss_decoder = self.init_ss_decoder( + ss_decoder_config_path, ss_decoder_ckpt_path + ) + ss_encoder = self.init_ss_encoder( + ss_encoder_config_path, ss_encoder_ckpt_path + ) + slat_decoder_gs = self.init_slat_decoder_gs( + slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path + ) + slat_decoder_gs_4 = self.init_slat_decoder_gs( + slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path + ) + slat_decoder_mesh = self.init_slat_decoder_mesh( + slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path + ) + + # Load conditioner embedder so that we only load it once + ss_condition_embedder = self.init_ss_condition_embedder( + ss_generator_config_path, ss_generator_ckpt_path + ) + slat_condition_embedder = self.init_slat_condition_embedder( + slat_generator_config_path, slat_generator_ckpt_path + ) + self.device = raw_device + + self.condition_embedders = { + "ss_condition_embedder": ss_condition_embedder, + "slat_condition_embedder": slat_condition_embedder, + } + + # override generator and condition embedder setting + self.override_ss_generator_cfg_config( + ss_generator, + cfg_strength=ss_cfg_strength, + inference_steps=ss_inference_steps, + rescale_t=ss_rescale_t, + cfg_interval=ss_cfg_interval, + cfg_strength_pm=ss_cfg_strength_pm, + ) + self.override_slat_generator_cfg_config( + slat_generator, + cfg_strength=slat_cfg_strength, + inference_steps=slat_inference_steps, + rescale_t=slat_rescale_t, + cfg_interval=slat_cfg_interval, + ) + + self.models = torch.nn.ModuleDict( + { + "ss_generator": ss_generator, + "slat_generator": slat_generator, + "ss_encoder": ss_encoder, + "ss_decoder": ss_decoder, + "slat_decoder_gs": slat_decoder_gs, + "slat_decoder_gs_4": slat_decoder_gs_4, + "slat_decoder_mesh": slat_decoder_mesh, + } + ) + logger.info("Loading SAM3D model weights completed.") + + if self.compile_model: + logger.info("Compiling model...") + self._compile() + logger.info("Model compilation completed!") + self.slat_mean = torch.tensor(slat_mean) + self.slat_std = torch.tensor(slat_std) + + InferencePipeline.__init__ = patch_init + + patch_pointmap_infer_pipeline() + patch_infer_init() + + return diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py index 8feb7ec..3a68ca1 100644 --- a/embodied_gen/utils/process_media.py +++ b/embodied_gen/utils/process_media.py @@ -96,7 +96,7 @@ def render_asset3d( image_paths = render_asset3d( mesh_path="path_to_mesh.obj", output_root="path_to_save_dir", - num_images=6, + num_images=4, elevation=(30, -30), output_subdir="renders", no_index_file=True, @@ -230,6 +230,29 @@ def filter_image_small_connected_components( return image +def keep_largest_connected_component(pil_img: Image.Image) -> Image.Image: + if pil_img.mode != "RGBA": + pil_img = pil_img.convert("RGBA") + + img_arr = np.array(pil_img) + alpha_channel = img_arr[:, :, 3] + + _, binary_mask = cv2.threshold(alpha_channel, 0, 255, cv2.THRESH_BINARY) + num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( + binary_mask, connectivity=8 + ) + if num_labels < 2: + return pil_img + + largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA]) + new_alpha = np.where(labels == largest_label, alpha_channel, 0).astype( + np.uint8 + ) + img_arr[:, :, 3] = new_alpha + + return Image.fromarray(img_arr) + + def combine_images_to_grid( images: list[str | Image.Image], 
cat_row_col: tuple[int, int] = None, @@ -439,7 +462,7 @@ def render( plt.axis("off") legend_handles = [ - Patch(facecolor=color, edgecolor='black', label=role) + Patch(facecolor=color, edgecolor="black", label=role) for role, color in self.role_colors.items() ] plt.legend( @@ -465,7 +488,7 @@ def load_scene_dict(file_path: str) -> dict: dict: Mapping from scene ID to description. """ scene_dict = {} - with open(file_path, "r", encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line or ":" not in line: @@ -487,7 +510,7 @@ def is_image_file(filename: str) -> bool: """ mime_type, _ = mimetypes.guess_type(filename) - return mime_type is not None and mime_type.startswith('image') + return mime_type is not None and mime_type.startswith("image") def parse_text_prompts(prompts: list[str]) -> list[str]: diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py index 9302331..b03d010 100644 --- a/embodied_gen/utils/tags.py +++ b/embodied_gen/utils/tags.py @@ -1 +1 @@ -VERSION = "v0.1.6" +VERSION = "v0.1.7" diff --git a/embodied_gen/utils/trender.py b/embodied_gen/utils/trender.py index 53acc50..f2a845f 100644 --- a/embodied_gen/utils/trender.py +++ b/embodied_gen/utils/trender.py @@ -16,29 +16,35 @@ import os import sys +from collections import defaultdict import numpy as np import spaces import torch +from easydict import EasyDict as edict from tqdm import tqdm current_file_path = os.path.abspath(__file__) current_dir = os.path.dirname(current_file_path) sys.path.append(os.path.join(current_dir, "../..")) -from thirdparty.TRELLIS.trellis.renderers.mesh_renderer import MeshRenderer -from thirdparty.TRELLIS.trellis.representations import MeshExtractResult +from thirdparty.TRELLIS.trellis.renderers import GaussianRenderer, MeshRenderer +from thirdparty.TRELLIS.trellis.representations import ( + Gaussian, + MeshExtractResult, +) from thirdparty.TRELLIS.trellis.utils.render_utils import ( - render_frames, yaw_pitch_r_fov_to_extrinsics_intrinsics, ) __all__ = [ "render_video", + "pack_state", + "unpack_state", ] @spaces.GPU -def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs): +def render_mesh_frames(sample, extrinsics, intrinsics, options={}, **kwargs): renderer = MeshRenderer() renderer.rendering_options.resolution = options.get("resolution", 512) renderer.rendering_options.near = options.get("near", 1) @@ -60,6 +66,57 @@ def render_mesh(sample, extrinsics, intrinsics, options={}, **kwargs): return rets +@spaces.GPU +def render_gs_frames( + sample, + extrinsics, + intrinsics, + options=None, + colors_overwrite=None, + verbose=True, + **kwargs, +): + def to_img(tensor): + return np.clip( + tensor.detach().cpu().numpy().transpose(1, 2, 0) * 255, 0, 255 + ).astype(np.uint8) + + def to_numpy(tensor): + return tensor.detach().cpu().numpy() + + renderer = GaussianRenderer() + renderer.pipe.kernel_size = kwargs.get("kernel_size", 0.1) + renderer.pipe.use_mip_gaussian = True + + defaults = { + "resolution": 512, + "near": 0.8, + "far": 1.6, + "bg_color": (0, 0, 0), + "ssaa": 1, + } + final_options = {**defaults, **(options or {})} + + for k, v in final_options.items(): + if hasattr(renderer.rendering_options, k): + setattr(renderer.rendering_options, k, v) + + outputs = defaultdict(list) + iterator = zip(extrinsics, intrinsics) + if verbose: + iterator = tqdm(iterator, total=len(extrinsics), desc="Rendering") + + for extr, intr in iterator: + res = renderer.render( + sample, extr, intr, 
colors_overwrite=colors_overwrite + ) + outputs["color"].append(to_img(res["color"])) + depth = res.get("percent_depth") or res.get("depth") + outputs["depth"].append(to_numpy(depth) if depth is not None else None) + + return dict(outputs) + + @spaces.GPU def render_video( sample, @@ -77,7 +134,9 @@ def render_video( yaws, pitch, r, fov ) render_fn = ( - render_mesh if isinstance(sample, MeshExtractResult) else render_frames + render_mesh_frames + if sample.__class__.__name__ == "MeshExtractResult" + else render_gs_frames ) result = render_fn( sample, @@ -88,3 +147,47 @@ def render_video( ) return result + + +@spaces.GPU +def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict: + return { + "gaussian": { + **gs.init_params, + "_xyz": gs._xyz.cpu().numpy(), + "_features_dc": gs._features_dc.cpu().numpy(), + "_scaling": gs._scaling.cpu().numpy(), + "_rotation": gs._rotation.cpu().numpy(), + "_opacity": gs._opacity.cpu().numpy(), + }, + "mesh": { + "vertices": mesh.vertices.cpu().numpy(), + "faces": mesh.faces.cpu().numpy(), + }, + } + + +def unpack_state(state: dict, device: str = "cpu") -> tuple[Gaussian, dict]: + gs = Gaussian( + aabb=state["gaussian"]["aabb"], + sh_degree=state["gaussian"]["sh_degree"], + mininum_kernel_size=state["gaussian"]["mininum_kernel_size"], + scaling_bias=state["gaussian"]["scaling_bias"], + opacity_bias=state["gaussian"]["opacity_bias"], + scaling_activation=state["gaussian"]["scaling_activation"], + device=device, + ) + gs._xyz = torch.tensor(state["gaussian"]["_xyz"], device=device) + gs._features_dc = torch.tensor( + state["gaussian"]["_features_dc"], device=device + ) + gs._scaling = torch.tensor(state["gaussian"]["_scaling"], device=device) + gs._rotation = torch.tensor(state["gaussian"]["_rotation"], device=device) + gs._opacity = torch.tensor(state["gaussian"]["_opacity"], device=device) + + mesh = edict( + vertices=torch.tensor(state["mesh"]["vertices"], device=device), + faces=torch.tensor(state["mesh"]["faces"], device=device), + ) + + return gs, mesh diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py index 6e77449..9feedde 100644 --- a/embodied_gen/validators/aesthetic_predictor.py +++ b/embodied_gen/validators/aesthetic_predictor.py @@ -125,7 +125,11 @@ def predict(self, image_path): Returns: float: Predicted aesthetic score. """ - pil_image = Image.open(image_path) + if isinstance(image_path, str): + pil_image = Image.open(image_path) + else: + pil_image = image_path + image = self.preprocess(pil_image).unsqueeze(0).to(self.device) with torch.no_grad(): diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py index 0e5ff7e..3cec795 100644 --- a/embodied_gen/validators/quality_checkers.py +++ b/embodied_gen/validators/quality_checkers.py @@ -126,6 +126,30 @@ def __init__( super().__init__(prompt, verbose) self.gpt_client = gpt_client if self.prompt is None: + # Old version for TRELLIS. + # self.prompt = """ + # You are an expert in evaluating the geometry quality of generated 3D asset. + # You will be given rendered views of a generated 3D asset, type {}, with black background. + # Your task is to evaluate the quality of the 3D asset generation, + # including geometry, structure, and appearance, based on the rendered views. + # Criteria: + # - Is the object in the image a single, complete, and well-formed instance, + # without truncation, missing parts, overlapping duplicates, or redundant geometry? 
+ # - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back, + # soft edges) are acceptable if the object is structurally sound and recognizable. + # - Only evaluate geometry. Do not assess texture quality. + # - The asset should not contain any unrelated elements, such as + # ground planes, platforms, or background props (e.g., paper, flooring). + + # If all the above criteria are met, return "YES". Otherwise, return + # "NO" followed by a brief explanation (no more than 20 words). + + # Example: + # Images show a yellow cup standing on a flat white plane -> NO + # -> Response: NO: extra white surface under the object. + # Image shows a chair with simplified back legs and soft edges -> YES + # """ + self.prompt = """ You are an expert in evaluating the geometry quality of generated 3D asset. You will be given rendered views of a generated 3D asset, type {}, with black background. @@ -137,16 +161,13 @@ def __init__( - Minor flaws, asymmetries, or simplifications (e.g., less detail on sides or back, soft edges) are acceptable if the object is structurally sound and recognizable. - Only evaluate geometry. Do not assess texture quality. - - The asset should not contain any unrelated elements, such as - ground planes, platforms, or background props (e.g., paper, flooring). - If all the above criteria are met, return "YES". Otherwise, return + If all the above criteria are met, return "YES" only. Otherwise, return "NO" followed by a brief explanation (no more than 20 words). Example: - Images show a yellow cup standing on a flat white plane -> NO - -> Response: NO: extra white surface under the object. - Image shows a chair with simplified back legs and soft edges → YES + Image shows a chair with one leg missing -> NO: the chair missing leg. 
+ Image shows a geometrically complete cup -> YES """ def query( diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py index 3f070be..8a48e94 100644 --- a/embodied_gen/validators/urdf_convertor.py +++ b/embodied_gen/validators/urdf_convertor.py @@ -27,7 +27,10 @@ from scipy.spatial.transform import Rotation from embodied_gen.data.convex_decomposer import decompose_convex_mesh from embodied_gen.utils.gpt_clients import GPT_CLIENT, GPTclient -from embodied_gen.utils.process_media import render_asset3d +from embodied_gen.utils.process_media import ( + combine_images_to_grid, + render_asset3d, +) from embodied_gen.utils.tags import VERSION logging.basicConfig(level=logging.INFO) @@ -482,7 +485,7 @@ def __call__( output_subdir=self.output_render_dir, no_index_file=True, ) - + # image_path = combine_images_to_grid(image_path) response = self.gpt_client.query(text_prompt, image_path) # logger.info(response) if response is None: diff --git a/install/install_basic.sh b/install/install_basic.sh index 63d4af4..ccbf861 100644 --- a/install/install_basic.sh +++ b/install/install_basic.sh @@ -8,7 +8,7 @@ PIP_INSTALL_PACKAGES=( "torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu118" "xformers==0.0.27.post2 --index-url https://download.pytorch.org/whl/cu118" "-r requirements.txt --use-deprecated=legacy-resolver" - "flash-attn==2.7.0.post2" + # "flash-attn==2.7.0.post2" "utils3d@git+https://github.com/EasternJournalist/utils3d.git@9a4eb15" "clip@git+https://github.com/openai/CLIP.git" "segment-anything@git+https://github.com/facebookresearch/segment-anything.git@dca509f" @@ -16,6 +16,8 @@ PIP_INSTALL_PACKAGES=( "kolors@git+https://github.com/HochCC/Kolors.git" "kaolin@git+https://github.com/NVIDIAGameWorks/kaolin.git@v0.16.0" "git+https://github.com/nerfstudio-project/gsplat.git@v1.5.3" + "git+https://github.com/facebookresearch/pytorch3d.git@stable" + "MoGe@git+https://github.com/microsoft/MoGe.git@a8c3734" ) for pkg in "${PIP_INSTALL_PACKAGES[@]}"; do diff --git a/install/install_extra.sh b/install/install_extra.sh index af63f7b..302e0ad 100644 --- a/install/install_extra.sh +++ b/install/install_extra.sh @@ -4,21 +4,17 @@ SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) source "$SCRIPT_DIR/_utils.sh" PYTHON_PACKAGES_NODEPS=( - "timm" "txt2panoimg@git+https://github.com/HochCC/SD-T2I-360PanoImage" ) PYTHON_PACKAGES=( - "ninja" - "fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98" + "fused-ssim@git+https://github.com/rahul-goel/fused-ssim#egg=328dc98 --no-build-isolation" "git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch" - "git+https://github.com/facebookresearch/pytorch3d.git@stable" "kornia" "h5py" "albumentations==0.5.2" "webdataset" "icecream" - "open3d" "pyequilib" ) diff --git a/pyproject.toml b/pyproject.toml index be3cc60..8ed7bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["embodied_gen"] [project] name = "embodied_gen" -version = "v0.1.6" +version = "v0.1.7" readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE", "NOTICE"] diff --git a/requirements.txt b/requirements.txt index 05fbfc3..c72fa6b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,4 +42,13 @@ coacd mani_skill==3.0.0b21 typing_extensions==4.14.1 ninja -packaging \ No newline at end of file +packaging +lightning +astor +optree +loguru +seaborn +hydra-core +modelscope +timm +open3d \ No newline at end of file diff --git 
a/tests/test_examples/test_quality_checkers.py b/tests/test_examples/test_quality_checkers.py
index 8b604d6..207c415 100644
--- a/tests/test_examples/test_quality_checkers.py
+++ b/tests/test_examples/test_quality_checkers.py
@@ -21,7 +21,10 @@
 import pytest
 
 from embodied_gen.utils.gpt_clients import GPT_CLIENT
-from embodied_gen.utils.process_media import render_asset3d
+from embodied_gen.utils.process_media import (
+    combine_images_to_grid,
+    render_asset3d,
+)
 from embodied_gen.validators.quality_checkers import (
     ImageAestheticChecker,
     ImageSegChecker,
@@ -166,12 +169,13 @@ def test_textgen_checker(textalign_checker, mesh_path, text_desc):
     image_list = render_asset3d(
         mesh_path,
         output_root=output_root,
-        num_images=6,
+        num_images=4,
         elevation=(30, -30),
         output_subdir="renders",
         no_index_file=True,
         with_mtl=False,
     )
+    image_list = combine_images_to_grid(image_list)
 
     flag, result = textalign_checker(text_desc, image_list)
     logger.info(f"textalign_checker: {flag}, {result}")
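
Usage sketch (illustrative only, not part of the patch): a minimal example of how the pieces added in this patch are intended to fit together, mirroring embodied_gen/scripts/imageto3d.py. The input image path is a placeholder, Sam3dInference is assumed to be constructed with its defaults, and r=1.85 is the render radius used in the script above.

from PIL import Image

from embodied_gen.models.sam3d import Sam3dInference
from embodied_gen.models.segment_model import RembgRemover
from embodied_gen.utils.inference import image3d_model_infer
from embodied_gen.utils.trender import render_video

# Segment the input image (placeholder path), then run the SAM3D backend.
image = Image.open("apps/assets/example_image/sample_00.jpg")
seg_image = RembgRemover()(image) if image.mode != "RGBA" else image

pipeline = Sam3dInference()  # or TrellisImageTo3DPipeline.from_pretrained(...)
outputs = image3d_model_infer(pipeline, seg_image, seed=0)

# image3d_model_infer repacks SAM3D outputs into the TRELLIS Gaussian/mesh
# format (via pack_state/unpack_state), so one renderer serves both backends.
color_frames = render_video(outputs["gaussian"][0], r=1.85)["color"]
normal_frames = render_video(outputs["mesh"][0], r=1.85)["normal"]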