From af6bc680a483c8f52e6ad1eecea38efe71e219df Mon Sep 17 00:00:00 2001 From: "xinjie.wang" Date: Thu, 6 Nov 2025 17:49:05 +0800 Subject: [PATCH 1/3] update --- docs/js/model_viewer.js | 34 ++++++++++++++++++++++ docs/services/image_to_3d.md | 32 -------------------- docs/services/text_to_3d.md | 32 -------------------- docs/services/texture_edit.md | 32 -------------------- docs/tutorials/any_simulators.md | 2 +- docs/tutorials/digital_twin.md | 2 +- docs/tutorials/gym_env.md | 4 +-- docs/tutorials/image_to_3d.md | 31 -------------------- docs/tutorials/index.md | 50 ++++++-------------------------- docs/tutorials/layout_gen.md | 19 +++++++----- docs/tutorials/scene_gen.md | 2 +- docs/tutorials/text_to_3d.md | 32 -------------------- docs/tutorials/texture_edit.md | 32 -------------------- mkdocs.yml | 8 +++++ 14 files changed, 67 insertions(+), 245 deletions(-) create mode 100644 docs/js/model_viewer.js diff --git a/docs/js/model_viewer.js b/docs/js/model_viewer.js new file mode 100644 index 0000000..12b86d8 --- /dev/null +++ b/docs/js/model_viewer.js @@ -0,0 +1,34 @@ +document.addEventListener('DOMContentLoaded', function () { + + const swiperElement = document.querySelector('.swiper1'); + + if (swiperElement) { + const swiper = new Swiper('.swiper1', { + loop: true, + slidesPerView: 3, + spaceBetween: 20, + navigation: { + nextEl: '.swiper-button-next', + prevEl: '.swiper-button-prev', + }, + centeredSlides: false, + noSwiping: true, + noSwipingClass: 'swiper-no-swiping', + watchSlidesProgress: true, + }); + + const modelViewers = swiperElement.querySelectorAll('model-viewer'); + + if (modelViewers.length > 0) { + let loadedCount = 0; + modelViewers.forEach(mv => { + mv.addEventListener('load', () => { + loadedCount++; + if (loadedCount === modelViewers.length) { + swiper.update(); + } + }); + }); + } + } +}); \ No newline at end of file diff --git a/docs/services/image_to_3d.md b/docs/services/image_to_3d.md index 10da09b..4e1a1c8 100644 --- a/docs/services/image_to_3d.md +++ b/docs/services/image_to_3d.md @@ -1,35 +1,3 @@ - - - - - - # πŸ–ΌοΈ Image-to-3D Service [![πŸ€— Hugging Face](https://img.shields.io/badge/πŸ€—-Image_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Image-to-3D) diff --git a/docs/services/text_to_3d.md b/docs/services/text_to_3d.md index a86bd31..4d590f5 100644 --- a/docs/services/text_to_3d.md +++ b/docs/services/text_to_3d.md @@ -1,35 +1,3 @@ - - - - - - # πŸ“ Text-to-3D Service [![πŸ€— Hugging Face](https://img.shields.io/badge/πŸ€—-Text_to_3D_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Text-to-3D) diff --git a/docs/services/texture_edit.md b/docs/services/texture_edit.md index 4898044..8a1c935 100644 --- a/docs/services/texture_edit.md +++ b/docs/services/texture_edit.md @@ -1,35 +1,3 @@ - - - - - - # 🎨 Texture Generation Service [![πŸ€— Hugging Face](https://img.shields.io/badge/πŸ€—-Texture_Gen_Demo-blue)](https://huggingface.co/spaces/HorizonRobotics/EmbodiedGen-Texture-Gen) diff --git a/docs/tutorials/any_simulators.md b/docs/tutorials/any_simulators.md index f8182b7..3c2d83e 100644 --- a/docs/tutorials/any_simulators.md +++ b/docs/tutorials/any_simulators.md @@ -58,6 +58,6 @@ dst_asset_path = cvt_embodiedgen_asset_to_anysim( ) ``` -simulators_collision +simulators_collision Collision and visualization mesh across simulators, showing consistent geometry and material fidelity. 
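A quick way to sanity-check a converted asset is to load it directly in the target simulator. The sketch below assumes the MJCF branch of the example above, that the standalone `mujoco` Python package is installed, and that `path1_to_target_dir/asset.xml` is the file written by the conversion; it is illustrative only and not part of the EmbodiedGen API.

```python
import mujoco  # standalone MuJoCo Python bindings

# Hypothetical output path from cvt_embodiedgen_asset_to_anysim above.
mjcf_path = "path1_to_target_dir/asset.xml"

# Loading the model catches missing meshes or malformed XML early.
model = mujoco.MjModel.from_xml_path(mjcf_path)
data = mujoco.MjData(model)
mujoco.mj_step(model, data)  # one physics step to confirm the asset simulates
print(f"Loaded {model.nbody} bodies and {model.ngeom} geoms")
```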
diff --git a/docs/tutorials/digital_twin.md b/docs/tutorials/digital_twin.md index e9fef46..85a7315 100644 --- a/docs/tutorials/digital_twin.md +++ b/docs/tutorials/digital_twin.md @@ -1,3 +1,3 @@ # Real-to-Sim Digital Twin Creation -real2sim_mujoco +real2sim_mujoco diff --git a/docs/tutorials/gym_env.md b/docs/tutorials/gym_env.md index 885917e..8696966 100644 --- a/docs/tutorials/gym_env.md +++ b/docs/tutorials/gym_env.md @@ -14,9 +14,9 @@ python embodied_gen/scripts/parallel_sim.py \ ```
- parallel_sim1 - parallel_sim2
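Alongside the `parallel_sim.py` script above, the same parallel environments can be created programmatically with `gym.make`, mirroring the `PickEmbodiedGen-v1` registration added later in this patch. A minimal sketch, assuming that importing `embodied_gen.envs.pick_embodiedgen` registers the environment and that the layout path below is a placeholder for a real generated layout:

```python
import gymnasium as gym

import embodied_gen.envs.pick_embodiedgen  # noqa: F401, assumed to register PickEmbodiedGen-v1

env = gym.make(
    "PickEmbodiedGen-v1",
    num_envs=4,
    render_mode="rgb_array",
    layout_file="outputs/layouts/task_0000/layout.json",  # hypothetical path
)
obs, info = env.reset(seed=0)
for _ in range(20):
    action = env.action_space.sample()  # random actions, just to exercise the sim
    obs, reward, terminated, truncated, info = env.step(action)
frame = env.render()  # RGB output; exact shape depends on render_mode
env.close()
```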
diff --git a/docs/tutorials/image_to_3d.md b/docs/tutorials/image_to_3d.md index d63dae1..2f7178a 100644 --- a/docs/tutorials/image_to_3d.md +++ b/docs/tutorials/image_to_3d.md @@ -1,34 +1,3 @@ - - - - - # πŸ–ΌοΈ Image-to-3D: Physically Plausible 3D Asset Generation Generate **physically plausible 3D assets** from a single input image, supporting **digital twin** and **simulation environments**. diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md index 014e67d..b1569d6 100644 --- a/docs/tutorials/index.md +++ b/docs/tutorials/index.md @@ -1,35 +1,3 @@ - - - - - - # Tutorials & Interface Usage Welcome to the tutorials for `EmbodiedGen`. `EmbodiedGen` is a powerful toolset for generating 3D assets, textures, scenes, and interactive layouts ready for simulators and digital twin environments. @@ -161,7 +129,7 @@ Generate **high-quality textures** for 3D meshes using **text prompts**, support Generate **physically consistent and visually coherent 3D environments** from text prompts. Typically used as **background** 3DGS scenes in simulators for efficient and photo-realistic rendering. - + --- @@ -170,10 +138,10 @@ Generate **physically consistent and visually coherent 3D environments** from te Generate diverse, physically realistic, and scalable **interactive 3D scenes** from natural language task descriptions, while also modeling the robot and manipulable objects.
- layout1 - layout2 - layout3 - layout4 + layout1 + layout2 + layout3 + layout4
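Layout scenes generated this way can also be instantiated per parallel environment from code: the `PickEmbodiedGen.init_env_layouts` helper added later in this patch writes one layout JSON per environment and, with `replace_objs=True`, re-samples object placements for every environment after the first. A sketch assuming a layout file produced by `layout-cli` already exists (the path is hypothetical):

```python
from embodied_gen.envs.pick_embodiedgen import PickEmbodiedGen

# One layout JSON per parallel environment; placements are re-sampled
# for every environment after the first when replace_objs=True.
layout_files = PickEmbodiedGen.init_env_layouts(
    layout_file="outputs/layouts/task_0000/layout.json",  # hypothetical path
    num_envs=4,
    replace_objs=True,
)
print(layout_files)  # paths to the generated per-environment layouts
```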
@@ -184,9 +152,9 @@ Generate diverse, physically realistic, and scalable **interactive 3D scenes** f Generate multiple **parallel simulation environments** with `gym.make` and record sensor and trajectory data.
- parallel_sim1 - parallel_sim2
@@ -198,11 +166,11 @@ Generate multiple **parallel simulation environments** with `gym.make` and recor Seamlessly use EmbodiedGen-generated assets in major simulators like **IsaacSim**, **MuJoCo**, **Genesis**, **PyBullet**, **IsaacGym**, and **SAPIEN**, featuring **accurate physical collisions** and **consistent visual appearance**.
- simulators_collision + simulators_collision
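Which format each of these simulators expects is captured by the `SimAssetMapper` lookup used in the conversion example earlier in this patch (USD for Isaac Sim, MJCF for MuJoCo and Genesis, URDF for the rest). A small sketch, assuming the lookup accepts the lower-case simulator names from that example after its move to `embodied_gen.utils.enum`:

```python
from embodied_gen.utils.enum import SimAssetMapper

# Print the asset format each supported simulator consumes.
for sim in ["isaacsim", "mujoco", "genesis", "pybullet", "isaacgym", "sapien"]:
    print(f"{sim} -> {SimAssetMapper[sim]}")
```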
## [πŸ”§ Real-to-Sim Digital Twin Creation](digital_twin.md)
- real2sim_mujoco + real2sim_mujoco
diff --git a/docs/tutorials/layout_gen.md b/docs/tutorials/layout_gen.md index 7af14b4..3109ff1 100644 --- a/docs/tutorials/layout_gen.md +++ b/docs/tutorials/layout_gen.md @@ -3,10 +3,10 @@ Layout Generation enables the generation of diverse, physically realistic, and scalable **interactive 3D scenes** directly from natural language task descriptions, while also modeling the robot's pose and relationships with manipulable objects. Target objects are randomly placed within the robot's reachable range, making the scenes readily usable for downstream simulation and reinforcement learning tasks in any mainstream simulator.
- layout1 - layout2 - layout3 - layout4 + layout1 + layout2 + layout3 + layout4
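The note and commands just below fetch pre-generated background scenes with the `hf` CLI; the same download can be scripted from Python via `huggingface_hub`, assuming the `xinjjj/scene3d-bg` dataset id used in that section:

```python
from huggingface_hub import snapshot_download

# Mirrors: hf download xinjjj/scene3d-bg --repo-type dataset --local-dir outputs
snapshot_download(
    repo_id="xinjjj/scene3d-bg",
    repo_type="dataset",
    local_dir="outputs",
)
```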
!!! note "Model Requirement" @@ -26,7 +26,7 @@ Each scene takes approximately **30 minutes** to generate. For efficiency, we re hf download xinjjj/scene3d-bg --repo-type dataset --local-dir outputs # Option 2: Download a larger background set (~14 GB) -hf download xinjjj/EmbodiedGenRLv2-BG --repo-type dataset --local-dir outputs +hf download xinjjj..RLv2-BG --repo-type dataset --local-dir outputs ``` ## Generate Interactive Layout Scenes @@ -43,12 +43,15 @@ layout-cli \ ``` You will get the following results: -
- Iscene_demo1 - Iscene_demo2 +
+ Iscene_demo1 + Iscene_demo2
+ ### Batch Generation You can also run multiple tasks via a task list file in the backend. diff --git a/docs/tutorials/scene_gen.md b/docs/tutorials/scene_gen.md index b7029a2..15c5a14 100644 --- a/docs/tutorials/scene_gen.md +++ b/docs/tutorials/scene_gen.md @@ -4,7 +4,7 @@ Generate **physically consistent and visually coherent 3D environments** from te --- - + --- diff --git a/docs/tutorials/text_to_3d.md b/docs/tutorials/text_to_3d.md index bef2f24..3a81366 100644 --- a/docs/tutorials/text_to_3d.md +++ b/docs/tutorials/text_to_3d.md @@ -1,35 +1,3 @@ - - - - - - # πŸ“ Text-to-3D: Generate 3D Assets from Text Create **physically plausible 3D assets** from **text descriptions**, supporting a wide range of geometry, style, and material details. diff --git a/docs/tutorials/texture_edit.md b/docs/tutorials/texture_edit.md index 8a03ec7..480d81c 100644 --- a/docs/tutorials/texture_edit.md +++ b/docs/tutorials/texture_edit.md @@ -1,35 +1,3 @@ - - - - - - # 🎨 Texture Generation: Create Visually Rich Textures for 3D Meshes Generate **high-quality textures** for 3D meshes using **text prompts**, supporting both **Chinese and English**. This allows you to enhance the visual appearance of existing 3D assets for simulation, visualization, or digital twin applications. diff --git a/mkdocs.yml b/mkdocs.yml index a280926..d55214c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -3,6 +3,7 @@ site_url: https://horizonrobotics.github.io/EmbodiedGen/ repo_name: "EmbodiedGen" repo_url: https://github.com/HorizonRobotics/EmbodiedGen copyright: "Copyright (c) 2025 Horizon Robotics" +use_directory_urls: false nav: - 🏠 Home: index.md @@ -102,6 +103,13 @@ plugins: extra_css: - stylesheets/extra.css + - https://cdn.jsdelivr.net/npm/swiper/swiper-bundle.min.css + +extra_javascript: + - https://cdn.jsdelivr.net/npm/swiper/swiper-bundle.min.js + - path: https://unpkg.com/@google/model-viewer/dist/model-viewer.min.js + type: module + - js/model_viewer.js markdown_extensions: - pymdownx.highlight From dc005ecd95b82b8138f31085f231f5b1a605c06d Mon Sep 17 00:00:00 2001 From: "xinjie.wang" Date: Thu, 6 Nov 2025 19:22:21 +0800 Subject: [PATCH 2/3] update --- README.md | 1 + docs/index.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 3e4b4c4..7e3fddb 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # *EmbodiedGen*: Towards a Generative 3D World Engine for Embodied Intelligence [![πŸ“– Documentation](https://img.shields.io/badge/πŸ“–-Documentation-blue)](https://horizonrobotics.github.io/EmbodiedGen/) +[![GitHub](https://img.shields.io/badge/GitHub-EmbodiedGen-black?logo=github)](https://github.com/HorizonRobotics/EmbodiedGen) [![πŸ“„ arXiv](https://img.shields.io/badge/πŸ“„-arXiv-b31b1b)](https://arxiv.org/abs/2506.10600) [![πŸŽ₯ Video](https://img.shields.io/badge/πŸŽ₯-Video-red)](https://www.youtube.com/watch?v=rG4odybuJRk) [![中文介绍](https://img.shields.io/badge/中文介绍-07C160?logo=wechat&logoColor=white)](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw) diff --git a/docs/index.md b/docs/index.md index e191f96..2a777c6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,6 +6,7 @@ hide: # πŸ‘‹ Welcome to EmbodiedGen [![πŸ“– Documentation](https://img.shields.io/badge/πŸ“–-Documentation-blue)](https://horizonrobotics.github.io/EmbodiedGen/) +[![GitHub](https://img.shields.io/badge/GitHub-EmbodiedGen-black?logo=github)](https://github.com/HorizonRobotics/EmbodiedGen) [![πŸ“„ arXiv](https://img.shields.io/badge/πŸ“„-arXiv-b31b1b)](https://arxiv.org/abs/2506.10600) [![πŸŽ₯ 
Video](https://img.shields.io/badge/πŸŽ₯-Video-red)](https://www.youtube.com/watch?v=rG4odybuJRk) [![中文介绍](https://img.shields.io/badge/中文介绍-07C160?logo=wechat&logoColor=white)](https://mp.weixin.qq.com/s/HH1cPBhK2xcDbyCK4BBTbw) From 4ec6d1ee8b53610f7f4ee3131c5205a7d967fb0e Mon Sep 17 00:00:00 2001 From: "xinjie.wang" Date: Fri, 7 Nov 2025 19:09:15 +0800 Subject: [PATCH 3/3] update --- README.md | 2 +- apps/visualize_asset.py | 3 +- docs/install.md | 2 +- docs/tutorials/any_simulators.md | 7 +- embodied_gen/data/asset_converter.py | 338 +++++++++++++----- embodied_gen/data/backproject_v2.py | 200 +++++++++-- embodied_gen/data/convex_decomposer.py | 79 +++- embodied_gen/data/differentiable_render.py | 97 +++-- embodied_gen/data/mesh_operator.py | 4 - embodied_gen/envs/pick_embodiedgen.py | 195 ++++++++++ embodied_gen/models/delight_model.py | 2 +- embodied_gen/models/image_comm_model.py | 138 +++++++ embodied_gen/models/layout.py | 82 +++++ embodied_gen/models/segment_model.py | 145 +++++++- embodied_gen/models/sr_model.py | 63 +++- embodied_gen/models/text_model.py | 59 +++ embodied_gen/trainer/pono2mesh_trainer.py | 149 +++++++- embodied_gen/utils/enum.py | 137 +++++++ embodied_gen/utils/geometry.py | 156 ++++++-- embodied_gen/utils/gpt_clients.py | 52 ++- embodied_gen/utils/process_media.py | 171 ++++++++- embodied_gen/utils/simulation.py | 165 +++++++-- embodied_gen/utils/tags.py | 2 +- .../validators/aesthetic_predictor.py | 20 +- embodied_gen/validators/quality_checkers.py | 71 +++- embodied_gen/validators/urdf_convertor.py | 78 +++- pyproject.toml | 2 +- tests/test_examples/test_asset_converter.py | 8 +- 28 files changed, 2160 insertions(+), 267 deletions(-) diff --git a/README.md b/README.md index 7e3fddb..220db79 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.5 +git checkout v0.1.6 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. conda activate embodiedgen diff --git a/apps/visualize_asset.py b/apps/visualize_asset.py index 089e329..85e12dd 100644 --- a/apps/visualize_asset.py +++ b/apps/visualize_asset.py @@ -31,8 +31,8 @@ import gradio as gr import pandas as pd -import yaml from app_style import custom_theme, lighting_css +from embodied_gen.utils.tags import VERSION try: from embodied_gen.utils.gpt_clients import GPT_CLIENT as gpt_client @@ -48,7 +48,6 @@ # --- Configuration & Data Loading --- -VERSION = "v0.1.5" RUNNING_MODE = "local" # local or hf_remote CSV_FILE = "dataset_index.csv" diff --git a/docs/install.md b/docs/install.md index 56d200f..8262eba 100644 --- a/docs/install.md +++ b/docs/install.md @@ -7,7 +7,7 @@ hide: ```sh git clone https://github.com/HorizonRobotics/EmbodiedGen.git cd EmbodiedGen -git checkout v0.1.5 +git checkout v0.1.6 git submodule update --init --recursive --progress conda create -n embodiedgen python=3.10.13 -y # recommended to use a new env. 
conda activate embodiedgen diff --git a/docs/tutorials/any_simulators.md b/docs/tutorials/any_simulators.md index 3c2d83e..0d6e146 100644 --- a/docs/tutorials/any_simulators.md +++ b/docs/tutorials/any_simulators.md @@ -35,7 +35,8 @@ Leverage **EmbodiedGen-generated assets** with *accurate physical collisions* an ## 🧱 Example: Conversion to Target Simulator ```python -from embodied_gen.data.asset_converter import SimAssetMapper, cvt_embodiedgen_asset_to_anysim +from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim +from embodied_gen.utils.enum import AssetType, SimAssetMapper from typing import Literal simulator_name: Literal[ @@ -52,6 +53,10 @@ dst_asset_path = cvt_embodiedgen_asset_to_anysim( "path1_to_embodiedgen_asset/asset.urdf", "path2_to_embodiedgen_asset/asset.urdf", ], + target_dirs=[ + "path1_to_target_dir/asset.usd", + "path2_to_target_dir/asset.usd", + ], target_type=SimAssetMapper[simulator_name], source_type=AssetType.MESH, overwrite=True, diff --git a/embodied_gen/data/asset_converter.py b/embodied_gen/data/asset_converter.py index 3b32c93..f4e1ac6 100644 --- a/embodied_gen/data/asset_converter.py +++ b/embodied_gen/data/asset_converter.py @@ -4,12 +4,12 @@ import os import xml.etree.ElementTree as ET from abc import ABC, abstractmethod -from dataclasses import dataclass from glob import glob from shutil import copy, copytree, rmtree import trimesh from scipy.spatial.transform import Rotation +from embodied_gen.utils.enum import AssetType logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -17,75 +17,62 @@ __all__ = [ "AssetConverterFactory", - "AssetType", "MeshtoMJCFConverter", "MeshtoUSDConverter", "URDFtoUSDConverter", "cvt_embodiedgen_asset_to_anysim", "PhysicsUSDAdder", - "SimAssetMapper", ] -@dataclass -class AssetType(str): - """Asset type enumeration.""" - - MJCF = "mjcf" - USD = "usd" - URDF = "urdf" - MESH = "mesh" - - -class SimAssetMapper: - _mapping = dict( - ISAACSIM=AssetType.USD, - ISAACGYM=AssetType.URDF, - MUJOCO=AssetType.MJCF, - GENESIS=AssetType.MJCF, - SAPIEN=AssetType.URDF, - PYBULLET=AssetType.URDF, - ) - - @classmethod - def __class_getitem__(cls, key: str): - key = key.upper() - if key.startswith("SAPIEN"): - key = "SAPIEN" - return cls._mapping[key] - - def cvt_embodiedgen_asset_to_anysim( urdf_files: list[str], + target_dirs: list[str], target_type: AssetType, source_type: AssetType, overwrite: bool = False, **kwargs, ) -> dict[str, str]: - """Convert URDF files generated by EmbodiedGen into the format required by all simulators. + """Convert URDF files generated by EmbodiedGen into formats required by simulators. Supported simulators include SAPIEN, Isaac Sim, MuJoCo, Isaac Gym, Genesis, and Pybullet. + Converting to the `USD` format requires `isaacsim` to be installed. Example: + ```py + from embodied_gen.data.asset_converter import cvt_embodiedgen_asset_to_anysim + from embodied_gen.utils.enum import AssetType + dst_asset_path = cvt_embodiedgen_asset_to_anysim( - urdf_files, - target_type=SimAssetMapper[simulator_name], + urdf_files=[ + "path1_to_embodiedgen_asset/asset.urdf", + "path2_to_embodiedgen_asset/asset.urdf", + ], + target_dirs=[ + "path1_to_target_dir/asset.usd", + "path2_to_target_dir/asset.usd", + ], + target_type=AssetType.USD, source_type=AssetType.MESH, ) + ``` Args: - urdf_files (List[str]): List of URDF file paths to be converted. - target_type (AssetType): The target asset type. - source_type (AssetType): The source asset type. 
- overwrite (bool): Whether to overwrite existing converted files. - **kwargs: Additional keyword arguments for the converter. + urdf_files (list[str]): List of URDF file paths. + target_dirs (list[str]): List of target directories. + target_type (AssetType): Target asset type. + source_type (AssetType): Source asset type. + overwrite (bool, optional): Overwrite existing files. + **kwargs: Additional converter arguments. Returns: - Dict[str, str]: A dictionary mapping the original URDF file path to the converted asset file path. + dict[str, str]: Mapping from URDF file to converted asset file. """ if isinstance(urdf_files, str): urdf_files = [urdf_files] + if isinstance(target_dirs, str): + urdf_files = [target_dirs] # If the target type is URDF, no conversion is needed. if target_type == AssetType.URDF: @@ -99,18 +86,17 @@ def cvt_embodiedgen_asset_to_anysim( asset_paths = dict() with asset_converter: - for urdf_file in urdf_files: + for urdf_file, target_dir in zip(urdf_files, target_dirs): filename = os.path.basename(urdf_file).replace(".urdf", "") - asset_dir = os.path.dirname(urdf_file) if target_type == AssetType.MJCF: - target_file = f"{asset_dir}/../mjcf/{filename}.xml" + target_file = f"{target_dir}/{filename}.xml" elif target_type == AssetType.USD: - target_file = f"{asset_dir}/../usd/{filename}.usd" + target_file = f"{target_dir}/{filename}.usd" else: raise NotImplementedError( f"Target type {target_type} not supported." ) - if not os.path.exists(target_file): + if not os.path.exists(target_file) or overwrite: asset_converter.convert(urdf_file, target_file) asset_paths[urdf_file] = target_file @@ -119,16 +105,35 @@ def cvt_embodiedgen_asset_to_anysim( class AssetConverterBase(ABC): - """Converter abstract base class.""" + """Abstract base class for asset converters. + + Provides context management and mesh transformation utilities. + """ @abstractmethod def convert(self, urdf_path: str, output_path: str, **kwargs) -> str: + """Convert an asset file. + + Args: + urdf_path (str): Path to input URDF file. + output_path (str): Path to output file. + **kwargs: Additional arguments. + + Returns: + str: Path to converted asset. + """ pass def transform_mesh( self, input_mesh: str, output_mesh: str, mesh_origin: ET.Element ) -> None: - """Apply transform to the mesh based on the origin element in URDF.""" + """Apply transform to mesh based on URDF origin element. + + Args: + input_mesh (str): Path to input mesh. + output_mesh (str): Path to output mesh. + mesh_origin (ET.Element): Origin element from URDF. + """ mesh = trimesh.load(input_mesh, group_material=False) rpy = list(map(float, mesh_origin.get("rpy").split(" "))) rotation = Rotation.from_euler("xyz", rpy, degrees=False) @@ -150,14 +155,19 @@ def transform_mesh( return def __enter__(self): + """Context manager entry.""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" return False class MeshtoMJCFConverter(AssetConverterBase): - """Convert URDF files into MJCF format.""" + """Converts mesh-based URDF files to MJCF format. + + Handles geometry, materials, and asset copying. + """ def __init__( self, @@ -166,6 +176,12 @@ def __init__( self.kwargs = kwargs def _copy_asset_file(self, src: str, dst: str) -> None: + """Copies asset file if not already present. + + Args: + src (str): Source file path. + dst (str): Destination file path. 
+ """ if os.path.exists(dst): return os.makedirs(os.path.dirname(dst), exist_ok=True) @@ -183,7 +199,19 @@ def add_geometry( material: ET.Element | None = None, is_collision: bool = False, ) -> None: - """Add geometry to the MJCF body from the URDF link.""" + """Adds geometry to MJCF body from URDF link. + + Args: + mujoco_element (ET.Element): MJCF asset element. + link (ET.Element): URDF link element. + body (ET.Element): MJCF body element. + tag (str): Tag name ("visual" or "collision"). + input_dir (str): Input directory. + output_dir (str): Output directory. + mesh_name (str): Mesh name. + material (ET.Element, optional): Material element. + is_collision (bool, optional): If True, treat as collision geometry. + """ element = link.find(tag) geometry = element.find("geometry") mesh = geometry.find("mesh") @@ -242,7 +270,20 @@ def add_materials( name: str, reflectance: float = 0.2, ) -> ET.Element: - """Add materials to the MJCF asset from the URDF link.""" + """Adds materials to MJCF asset from URDF link. + + Args: + mujoco_element (ET.Element): MJCF asset element. + link (ET.Element): URDF link element. + tag (str): Tag name. + input_dir (str): Input directory. + output_dir (str): Output directory. + name (str): Material name. + reflectance (float, optional): Reflectance value. + + Returns: + ET.Element: Material element. + """ element = link.find(tag) geometry = element.find("geometry") mesh = geometry.find("mesh") @@ -282,7 +323,12 @@ def add_materials( return material def convert(self, urdf_path: str, mjcf_path: str): - """Convert a URDF file to MJCF format.""" + """Converts a URDF file to MJCF format. + + Args: + urdf_path (str): Path to URDF file. + mjcf_path (str): Path to output MJCF file. + """ tree = ET.parse(urdf_path) root = tree.getroot() @@ -336,10 +382,22 @@ def convert(self, urdf_path: str, mjcf_path: str): class URDFtoMJCFConverter(MeshtoMJCFConverter): - """Convert URDF files with joints to MJCF format, handling transformations from joints.""" + """Converts URDF files with joints to MJCF format, handling joint transformations. + + Handles fixed joints and hierarchical body structure. + """ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str: - """Convert a URDF file with joints to MJCF format.""" + """Converts a URDF file with joints to MJCF format. + + Args: + urdf_path (str): Path to URDF file. + mjcf_path (str): Path to output MJCF file. + **kwargs: Additional arguments. + + Returns: + str: Path to converted MJCF file. + """ tree = ET.parse(urdf_path) root = tree.getroot() @@ -423,7 +481,10 @@ def convert(self, urdf_path: str, mjcf_path: str, **kwargs) -> str: class MeshtoUSDConverter(AssetConverterBase): - """Convert Mesh file from URDF into USD format.""" + """Converts mesh-based URDF files to USD format. + + Adds physics APIs and post-processes collision meshes. + """ DEFAULT_BIND_APIS = [ "MaterialBindingAPI", @@ -443,6 +504,14 @@ def __init__( simulation_app=None, **kwargs, ): + """Initializes the converter. + + Args: + force_usd_conversion (bool, optional): Force USD conversion. + make_instanceable (bool, optional): Make prims instanceable. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. 
+ """ if simulation_app is not None: self.simulation_app = simulation_app @@ -458,6 +527,7 @@ def __init__( ) def __enter__(self): + """Context manager entry, launches simulation app if needed.""" from isaaclab.app import AppLauncher if not hasattr(self, "simulation_app"): @@ -476,6 +546,7 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit, closes simulation app if created.""" # Close the simulation app if it was created here if hasattr(self, "app_launcher") and self.exit_close: self.simulation_app.close() @@ -486,7 +557,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False def convert(self, urdf_path: str, output_file: str): - """Convert a URDF file to USD and post-process collision meshes.""" + """Converts a URDF file to USD and post-processes collision meshes. + + Args: + urdf_path (str): Path to URDF file. + output_file (str): Path to output USD file. + """ from isaaclab.sim.converters import MeshConverter, MeshConverterCfg from pxr import PhysxSchema, Sdf, Usd, UsdShade @@ -556,6 +632,11 @@ def convert(self, urdf_path: str, output_file: str): class PhysicsUSDAdder(MeshtoUSDConverter): + """Adds physics APIs and collision properties to USD assets. + + Useful for post-processing USD files for simulation. + """ + DEFAULT_BIND_APIS = [ "MaterialBindingAPI", # "PhysicsMeshCollisionAPI", @@ -566,6 +647,12 @@ class PhysicsUSDAdder(MeshtoUSDConverter): ] def convert(self, usd_path: str, output_file: str = None): + """Adds physics APIs and collision properties to a USD file. + + Args: + usd_path (str): Path to input USD file. + output_file (str, optional): Path to output USD file. + """ from pxr import PhysxSchema, Sdf, Usd, UsdGeom, UsdPhysics if output_file is None: @@ -626,14 +713,18 @@ def convert(self, usd_path: str, output_file: str = None): class URDFtoUSDConverter(MeshtoUSDConverter): - """Convert URDF files into USD format. + """Converts URDF files to USD format. Args: - fix_base (bool): Whether to fix the base link. - merge_fixed_joints (bool): Whether to merge fixed joints. - make_instanceable (bool): Whether to make prims instanceable. - force_usd_conversion (bool): Force conversion to USD. - collision_from_visuals (bool): Generate collisions from visuals if not provided. + fix_base (bool, optional): Fix the base link. + merge_fixed_joints (bool, optional): Merge fixed joints. + make_instanceable (bool, optional): Make prims instanceable. + force_usd_conversion (bool, optional): Force conversion to USD. + collision_from_visuals (bool, optional): Generate collisions from visuals. + joint_drive (optional): Joint drive configuration. + rotate_wxyz (tuple[float], optional): Quaternion for rotation. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. """ def __init__( @@ -648,6 +739,19 @@ def __init__( simulation_app=None, **kwargs, ): + """Initializes the converter. + + Args: + fix_base (bool, optional): Fix the base link. + merge_fixed_joints (bool, optional): Merge fixed joints. + make_instanceable (bool, optional): Make prims instanceable. + force_usd_conversion (bool, optional): Force conversion to USD. + collision_from_visuals (bool, optional): Generate collisions from visuals. + joint_drive (optional): Joint drive configuration. + rotate_wxyz (tuple[float], optional): Quaternion for rotation. + simulation_app (optional): Simulation app instance. + **kwargs: Additional arguments. 
+ """ self.usd_parms = dict( fix_base=fix_base, merge_fixed_joints=merge_fixed_joints, @@ -662,7 +766,12 @@ def __init__( self.simulation_app = simulation_app def convert(self, urdf_path: str, output_file: str): - """Convert a URDF file to USD and post-process collision meshes.""" + """Converts a URDF file to USD and post-processes collision meshes. + + Args: + urdf_path (str): Path to URDF file. + output_file (str): Path to output USD file. + """ from isaaclab.sim.converters import UrdfConverter, UrdfConverterCfg from pxr import Gf, PhysxSchema, Sdf, Usd, UsdGeom @@ -723,13 +832,36 @@ def convert(self, urdf_path: str, output_file: str): class AssetConverterFactory: - """Factory class for creating asset converters based on target and source types.""" + """Factory for creating asset converters based on target and source types. + + Example: + ```py + from embodied_gen.data.asset_converter import AssetConverterFactory + from embodied_gen.utils.enum import AssetType + + converter = AssetConverterFactory.create( + target_type=AssetType.USD, source_type=AssetType.MESH + ) + with converter: + for urdf_path, output_file in zip(urdf_paths, output_files): + converter.convert(urdf_path, output_file) + ``` + """ @staticmethod def create( target_type: AssetType, source_type: AssetType = "urdf", **kwargs ) -> AssetConverterBase: - """Create an asset converter instance based on target and source types.""" + """Creates an asset converter instance. + + Args: + target_type (AssetType): Target asset type. + source_type (AssetType, optional): Source asset type. + **kwargs: Additional arguments. + + Returns: + AssetConverterBase: Converter instance. + """ if target_type == AssetType.MJCF and source_type == AssetType.MESH: converter = MeshtoMJCFConverter(**kwargs) elif target_type == AssetType.MJCF and source_type == AssetType.URDF: @@ -751,7 +883,14 @@ def create( # target_asset_type = AssetType.USD urdf_paths = [ - "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf", + 'outputs/EmbodiedGenData/demo_assets/banana/result/banana.urdf', + 'outputs/EmbodiedGenData/demo_assets/book/result/book.urdf', + 'outputs/EmbodiedGenData/demo_assets/lamp/result/lamp.urdf', + 'outputs/EmbodiedGenData/demo_assets/mug/result/mug.urdf', + 'outputs/EmbodiedGenData/demo_assets/remote_control/result/remote_control.urdf', + "outputs/EmbodiedGenData/demo_assets/rubik's_cube/result/rubik's_cube.urdf", + 'outputs/EmbodiedGenData/demo_assets/table/result/table.urdf', + 'outputs/EmbodiedGenData/demo_assets/vase/result/vase.urdf', ] if target_asset_type == AssetType.MJCF: @@ -765,7 +904,14 @@ def create( elif target_asset_type == AssetType.USD: output_files = [ - "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd", + 'outputs/EmbodiedGenData/demo_assets/banana/usd/banana.usd', + 'outputs/EmbodiedGenData/demo_assets/book/usd/book.usd', + 'outputs/EmbodiedGenData/demo_assets/lamp/usd/lamp.usd', + 'outputs/EmbodiedGenData/demo_assets/mug/usd/mug.usd', + 'outputs/EmbodiedGenData/demo_assets/remote_control/usd/remote_control.usd', + "outputs/EmbodiedGenData/demo_assets/rubik's_cube/usd/rubik's_cube.usd", + 'outputs/EmbodiedGenData/demo_assets/table/usd/table.usd', + 'outputs/EmbodiedGenData/demo_assets/vase/usd/vase.usd', ] asset_converter = AssetConverterFactory.create( target_type=AssetType.USD, @@ -776,33 +922,33 @@ def create( for urdf_path, output_file in zip(urdf_paths, output_files): asset_converter.convert(urdf_path, output_file) - urdf_path = 
"outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf" - output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd" - - asset_converter = AssetConverterFactory.create( - target_type=AssetType.USD, - source_type=AssetType.URDF, - rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis - ) - - with asset_converter: - asset_converter.convert(urdf_path, output_file) - - # Convert infinigen urdf to mjcf - urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf" - output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml" - asset_converter = AssetConverterFactory.create( - target_type=AssetType.MJCF, - source_type=AssetType.URDF, - keep_materials=["diffuse"], - ) - with asset_converter: - asset_converter.convert(urdf_path, output_file) - - # Convert infinigen usdc to physics usdc - converter = PhysicsUSDAdder() - with converter: - converter.convert( - usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc", - output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc", - ) + # urdf_path = "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf" + # output_file = "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd" + + # asset_converter = AssetConverterFactory.create( + # target_type=AssetType.USD, + # source_type=AssetType.URDF, + # rotate_wxyz=(0.7071, 0.7071, 0, 0), # rotate 90 deg around the X-axis + # ) + + # with asset_converter: + # asset_converter.convert(urdf_path, output_file) + + # # Convert infinigen urdf to mjcf + # urdf_path = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/export_scene/scene.urdf" + # output_file = "/home/users/xinjie.wang/xinjie/infinigen/outputs/exports/kitchen_i_urdf/mjcf/scene.xml" + # asset_converter = AssetConverterFactory.create( + # target_type=AssetType.MJCF, + # source_type=AssetType.URDF, + # keep_materials=["diffuse"], + # ) + # with asset_converter: + # asset_converter.convert(urdf_path, output_file) + + # # Convert infinigen usdc to physics usdc + # converter = PhysicsUSDAdder() + # with converter: + # converter.convert( + # usd_path="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc/export_scene/export_scene.usdc", + # output_file="/home/users/xinjie.wang/xinjie/infinigen/outputs/usdc_p3/export_scene/export_scene.usdc", + # ) diff --git a/embodied_gen/data/backproject_v2.py b/embodied_gen/data/backproject_v2.py index 420ee39..5908013 100644 --- a/embodied_gen/data/backproject_v2.py +++ b/embodied_gen/data/backproject_v2.py @@ -58,7 +58,16 @@ def _transform_vertices( mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False ) -> torch.Tensor: - """Transform 3D vertices using a projection matrix.""" + """Transforms 3D vertices using a projection matrix. + + Args: + mtx (torch.Tensor): Projection matrix. + pos (torch.Tensor): Vertex positions. + keepdim (bool, optional): If True, keeps the batch dimension. + + Returns: + torch.Tensor: Transformed vertices. 
+ """ t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype) if pos.size(-1) == 3: pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1) @@ -71,7 +80,17 @@ def _transform_vertices( def _bilinear_interpolation_scattering( image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor ) -> torch.Tensor: - """Bilinear interpolation scattering for grid-based value accumulation.""" + """Performs bilinear interpolation scattering for grid-based value accumulation. + + Args: + image_h (int): Image height. + image_w (int): Image width. + coords (torch.Tensor): Normalized coordinates. + values (torch.Tensor): Values to scatter. + + Returns: + torch.Tensor: Interpolated grid. + """ device = values.device dtype = values.dtype C = values.shape[-1] @@ -135,7 +154,18 @@ def _texture_inpaint_smooth( faces: np.ndarray, uv_map: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: - """Perform texture inpainting using vertex-based color propagation.""" + """Performs texture inpainting using vertex-based color propagation. + + Args: + texture (np.ndarray): Texture image. + mask (np.ndarray): Mask image. + vertices (np.ndarray): Mesh vertices. + faces (np.ndarray): Mesh faces. + uv_map (np.ndarray): UV coordinates. + + Returns: + tuple[np.ndarray, np.ndarray]: Inpainted texture and updated mask. + """ image_h, image_w, C = texture.shape N = vertices.shape[0] @@ -231,29 +261,41 @@ def _texture_inpaint_smooth( class TextureBacker: """Texture baking pipeline for multi-view projection and fusion. - This class performs UV-based texture generation for a 3D mesh using - multi-view color images, depth, and normal information. The pipeline - includes mesh normalization and UV unwrapping, visibility-aware - back-projection, confidence-weighted texture fusion, and inpainting - of missing texture regions. + This class generates UV-based textures for a 3D mesh using multi-view images, + depth, and normal information. It includes mesh normalization, UV unwrapping, + visibility-aware back-projection, confidence-weighted fusion, and inpainting. Args: - camera_params (CameraSetting): Camera intrinsics and extrinsics used - for rendering each view. - view_weights (list[float]): A list of weights for each view, used - to blend confidence maps during texture fusion. - render_wh (tuple[int, int], optional): Resolution (width, height) for - intermediate rendering passes. Defaults to (2048, 2048). - texture_wh (tuple[int, int], optional): Output texture resolution - (width, height). Defaults to (2048, 2048). - bake_angle_thresh (int, optional): Maximum angle (in degrees) between - view direction and surface normal for projection to be considered valid. - Defaults to 75. - mask_thresh (float, optional): Threshold applied to visibility masks - during rendering. Defaults to 0.5. - smooth_texture (bool, optional): If True, apply post-processing (e.g., - blurring) to the final texture. Defaults to True. - inpaint_smooth (bool, optional): If True, apply inpainting to smooth. + camera_params (CameraSetting): Camera intrinsics and extrinsics. + view_weights (list[float]): Weights for each view in texture fusion. + render_wh (tuple[int, int], optional): Intermediate rendering resolution. + texture_wh (tuple[int, int], optional): Output texture resolution. + bake_angle_thresh (int, optional): Max angle for valid projection. + mask_thresh (float, optional): Threshold for visibility masks. + smooth_texture (bool, optional): Apply post-processing to texture. + inpaint_smooth (bool, optional): Apply inpainting smoothing. 
+ + Example: + ```py + from embodied_gen.data.backproject_v2 import TextureBacker + from embodied_gen.data.utils import CameraSetting + import trimesh + from PIL import Image + + camera_params = CameraSetting( + num_images=6, + elevation=[20, -10], + distance=5, + resolution_hw=(2048,2048), + fov=math.radians(30), + device='cuda', + ) + view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02] + mesh = trimesh.load('mesh.obj') + images = [Image.open(f'view_{i}.png') for i in range(6)] + texture_backer = TextureBacker(camera_params, view_weights) + textured_mesh = texture_backer(images, mesh, 'output.obj') + ``` """ def __init__( @@ -283,6 +325,12 @@ def __init__( ) def _lazy_init_render(self, camera_params, mask_thresh): + """Lazily initializes the renderer. + + Args: + camera_params (CameraSetting): Camera settings. + mask_thresh (float): Mask threshold. + """ if self.renderer is None: camera = init_kal_camera(camera_params) mv = camera.view_matrix() # (n 4 4) world2cam @@ -301,6 +349,14 @@ def _lazy_init_render(self, camera_params, mask_thresh): ) def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: + """Normalizes mesh and unwraps UVs. + + Args: + mesh (trimesh.Trimesh): Input mesh. + + Returns: + trimesh.Trimesh: Mesh with normalized vertices and UVs. + """ mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices) self.scale, self.center = scale, center @@ -318,6 +374,16 @@ def get_mesh_np_attrs( scale: float = None, center: np.ndarray = None, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Gets mesh attributes as numpy arrays. + + Args: + mesh (trimesh.Trimesh): Input mesh. + scale (float, optional): Scale factor. + center (np.ndarray, optional): Center offset. + + Returns: + tuple: (vertices, faces, uv_map) + """ vertices = mesh.vertices.copy() faces = mesh.faces.copy() uv_map = mesh.visual.uv.copy() @@ -331,6 +397,14 @@ def get_mesh_np_attrs( return vertices, faces, uv_map def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor: + """Computes edge image from depth map. + + Args: + depth_image (torch.Tensor): Depth map. + + Returns: + torch.Tensor: Edge image. + """ depth_image_np = depth_image.cpu().numpy() depth_image_np = (depth_image_np * 255).astype(np.uint8) depth_edges = cv2.Canny(depth_image_np, 30, 80) @@ -344,6 +418,16 @@ def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor: def compute_enhanced_viewnormal( self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor ) -> torch.Tensor: + """Computes enhanced view normals for mesh faces. + + Args: + mv_mtx (torch.Tensor): View matrices. + vertices (torch.Tensor): Mesh vertices. + faces (torch.Tensor): Mesh faces. + + Returns: + torch.Tensor: View normals. + """ rast, _ = self.renderer.compute_dr_raster(vertices, faces) rendered_view_normals = [] for idx in range(len(mv_mtx)): @@ -376,6 +460,18 @@ def compute_enhanced_viewnormal( def back_project( self, image, vis_mask, depth, normal, uv ) -> tuple[torch.Tensor, torch.Tensor]: + """Back-projects image and confidence to UV texture space. + + Args: + image (PIL.Image or np.ndarray): Input image. + vis_mask (torch.Tensor): Visibility mask. + depth (torch.Tensor): Depth map. + normal (torch.Tensor): Normal map. + uv (torch.Tensor): UV coordinates. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Texture and confidence map. 
+ """ image = np.array(image) image = torch.as_tensor(image, device=self.device, dtype=torch.float32) if image.ndim == 2: @@ -418,6 +514,17 @@ def back_project( ) def _scatter_texture(self, uv, data, mask): + """Scatters data to texture using UV coordinates and mask. + + Args: + uv (torch.Tensor): UV coordinates. + data (torch.Tensor): Data to scatter. + mask (torch.Tensor): Mask for valid pixels. + + Returns: + torch.Tensor: Scattered texture. + """ + def __filter_data(data, mask): return data.view(-1, data.shape[-1])[mask] @@ -432,6 +539,15 @@ def __filter_data(data, mask): def fast_bake_texture( self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor] ) -> tuple[torch.Tensor, torch.Tensor]: + """Fuses multiple textures and confidence maps. + + Args: + textures (list[torch.Tensor]): List of textures. + confidence_maps (list[torch.Tensor]): List of confidence maps. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Fused texture and mask. + """ channel = textures[0].shape[-1] texture_merge = torch.zeros(self.texture_wh + [channel]).to( self.device @@ -451,6 +567,16 @@ def fast_bake_texture( def uv_inpaint( self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray ) -> np.ndarray: + """Inpaints missing regions in the UV texture. + + Args: + mesh (trimesh.Trimesh): Mesh. + texture (np.ndarray): Texture image. + mask (np.ndarray): Mask image. + + Returns: + np.ndarray: Inpainted texture. + """ if self.inpaint_smooth: vertices, faces, uv_map = self.get_mesh_np_attrs(mesh) texture, mask = _texture_inpaint_smooth( @@ -473,6 +599,15 @@ def compute_texture( colors: list[Image.Image], mesh: trimesh.Trimesh, ) -> trimesh.Trimesh: + """Computes the fused texture for the mesh from multi-view images. + + Args: + colors (list[Image.Image]): List of view images. + mesh (trimesh.Trimesh): Mesh to texture. + + Returns: + tuple[np.ndarray, np.ndarray]: Texture and mask. + """ self._lazy_init_render(self.camera_params, self.mask_thresh) vertices = torch.from_numpy(mesh.vertices).to(self.device).float() @@ -517,7 +652,7 @@ def __call__( Args: colors (list[Image.Image]): List of input view images. mesh (trimesh.Trimesh): Input mesh to be textured. - output_path (str): Path to save the output textured mesh (.obj or .glb). + output_path (str): Path to save the output textured mesh. Returns: trimesh.Trimesh: The textured mesh with UV and texture image. @@ -540,6 +675,11 @@ def __call__( def parse_args(): + """Parses command-line arguments for texture backprojection. + + Returns: + argparse.Namespace: Parsed arguments. + """ parser = argparse.ArgumentParser(description="Backproject texture") parser.add_argument( "--color_path", @@ -636,6 +776,16 @@ def entrypoint( imagesr_model: ImageRealESRGAN = None, **kwargs, ) -> trimesh.Trimesh: + """Entrypoint for texture backprojection from multi-view images. + + Args: + delight_model (DelightingModel, optional): Delighting model. + imagesr_model (ImageRealESRGAN, optional): Super-resolution model. + **kwargs: Additional arguments to override CLI. + + Returns: + trimesh.Trimesh: Textured mesh. 
+ """ args = parse_args() for k, v in kwargs.items(): if hasattr(args, k) and v is not None: diff --git a/embodied_gen/data/convex_decomposer.py b/embodied_gen/data/convex_decomposer.py index 88e6084..73c4a5a 100644 --- a/embodied_gen/data/convex_decomposer.py +++ b/embodied_gen/data/convex_decomposer.py @@ -39,6 +39,22 @@ def decompose_convex_coacd( auto_scale: bool = True, scale_factor: float = 1.0, ) -> None: + """Decomposes a mesh using CoACD and saves the result. + + This function loads a mesh from a file, runs the CoACD algorithm with the + given parameters, optionally scales the resulting convex hulls to match the + original mesh's bounding box, and exports the combined result to a file. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + params: A dictionary of parameters for the CoACD algorithm. + verbose: If True, sets the CoACD log level to 'info'. + auto_scale: If True, automatically computes a scale factor to match the + decomposed mesh's bounding box to the visual mesh's bounding box. + scale_factor: An additional scaling factor applied to the vertices of + the decomposed mesh parts. + """ coacd.set_log_level("info" if verbose else "warn") mesh = trimesh.load(filename, force="mesh") @@ -83,7 +99,38 @@ def decompose_convex_mesh( scale_factor: float = 1.005, verbose: bool = False, ) -> str: - """Decompose a mesh into convex parts using the CoACD algorithm.""" + """Decomposes a mesh into convex parts with retry logic. + + This function serves as a wrapper for `decompose_convex_coacd`, providing + explicit parameters for the CoACD algorithm and implementing a retry + mechanism. If the initial decomposition fails, it attempts again with + `preprocess_mode` set to 'on'. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + threshold: CoACD parameter. See CoACD documentation for details. + max_convex_hull: CoACD parameter. See CoACD documentation for details. + preprocess_mode: CoACD parameter. See CoACD documentation for details. + preprocess_resolution: CoACD parameter. See CoACD documentation for details. + resolution: CoACD parameter. See CoACD documentation for details. + mcts_nodes: CoACD parameter. See CoACD documentation for details. + mcts_iterations: CoACD parameter. See CoACD documentation for details. + mcts_max_depth: CoACD parameter. See CoACD documentation for details. + pca: CoACD parameter. See CoACD documentation for details. + merge: CoACD parameter. See CoACD documentation for details. + seed: CoACD parameter. See CoACD documentation for details. + auto_scale: If True, automatically scale the output to match the input + bounding box. + scale_factor: Additional scaling factor to apply. + verbose: If True, enables detailed logging. + + Returns: + The path to the output file if decomposition is successful. + + Raises: + RuntimeError: If convex decomposition fails after all attempts. + """ coacd.set_log_level("info" if verbose else "warn") if os.path.exists(outfile): @@ -148,9 +195,37 @@ def decompose_convex_mp( verbose: bool = False, auto_scale: bool = True, ) -> str: - """Decompose a mesh into convex parts using the CoACD algorithm in a separate process. + """Decomposes a mesh into convex parts in a separate process. + + This function uses the `multiprocessing` module to run the CoACD algorithm + in a spawned subprocess. This is useful for isolating the decomposition + process to prevent potential memory leaks or crashes in the main process. 
+ It includes a retry mechanism similar to `decompose_convex_mesh`. See https://simulately.wiki/docs/toolkits/ConvexDecomp for details. + + Args: + filename: Path to the input mesh file. + outfile: Path to save the decomposed output mesh. + threshold: CoACD parameter. + max_convex_hull: CoACD parameter. + preprocess_mode: CoACD parameter. + preprocess_resolution: CoACD parameter. + resolution: CoACD parameter. + mcts_nodes: CoACD parameter. + mcts_iterations: CoACD parameter. + mcts_max_depth: CoACD parameter. + pca: CoACD parameter. + merge: CoACD parameter. + seed: CoACD parameter. + verbose: If True, enables detailed logging in the subprocess. + auto_scale: If True, automatically scale the output. + + Returns: + The path to the output file if decomposition is successful. + + Raises: + RuntimeError: If convex decomposition fails after all attempts. """ params = dict( threshold=threshold, diff --git a/embodied_gen/data/differentiable_render.py b/embodied_gen/data/differentiable_render.py index fdd5a26..52a8406 100644 --- a/embodied_gen/data/differentiable_render.py +++ b/embodied_gen/data/differentiable_render.py @@ -66,6 +66,14 @@ def create_mp4_from_images( fps: int = 10, prompt: str = None, ): + """Creates an MP4 video from a list of images. + + Args: + images (list[np.ndarray]): List of images as numpy arrays. + output_path (str): Path to save the MP4 file. + fps (int, optional): Frames per second. Defaults to 10. + prompt (str, optional): Optional text prompt overlay. + """ font = cv2.FONT_HERSHEY_SIMPLEX font_scale = 0.5 font_thickness = 1 @@ -96,6 +104,13 @@ def create_mp4_from_images( def create_gif_from_images( images: list[np.ndarray], output_path: str, fps: int = 10 ) -> None: + """Creates a GIF animation from a list of images. + + Args: + images (list[np.ndarray]): List of images as numpy arrays. + output_path (str): Path to save the GIF file. + fps (int, optional): Frames per second. Defaults to 10. + """ pil_images = [] for image in images: image = image.clip(min=0, max=1) @@ -116,32 +131,47 @@ def create_gif_from_images( class ImageRender(object): - """A differentiable mesh renderer supporting multi-view rendering. + """Differentiable mesh renderer supporting multi-view rendering. - This class wraps a differentiable rasterization using `nvdiffrast` to - render mesh geometry to various maps (normal, depth, alpha, albedo, etc.). + This class wraps differentiable rasterization using `nvdiffrast` to render mesh + geometry to various maps (normal, depth, alpha, albedo, etc.) and supports + saving images and videos. Args: - render_items (list[RenderItems]): A list of rendering targets to - generate (e.g., IMAGE, DEPTH, NORMAL, etc.). - camera_params (CameraSetting): The camera parameters for rendering, - including intrinsic and extrinsic matrices. - recompute_vtx_normal (bool, optional): If True, recomputes - vertex normals from the mesh geometry. Defaults to True. - with_mtl (bool, optional): Whether to load `.mtl` material files - for meshes. Defaults to False. - gen_color_gif (bool, optional): Generate a GIF of rendered - color images. Defaults to False. - gen_color_mp4 (bool, optional): Generate an MP4 video of rendered - color images. Defaults to False. - gen_viewnormal_mp4 (bool, optional): Generate an MP4 video of - view-space normals. Defaults to False. - gen_glonormal_mp4 (bool, optional): Generate an MP4 video of - global-space normals. Defaults to False. - no_index_file (bool, optional): If True, skip saving the `index.json` - summary file. Defaults to False. 
- light_factor (float, optional): A scalar multiplier for - PBR light intensity. Defaults to 1.0. + render_items (list[RenderItems]): List of rendering targets. + camera_params (CameraSetting): Camera parameters for rendering. + recompute_vtx_normal (bool, optional): Recompute vertex normals. Defaults to True. + with_mtl (bool, optional): Load mesh material files. Defaults to False. + gen_color_gif (bool, optional): Generate GIF of color images. Defaults to False. + gen_color_mp4 (bool, optional): Generate MP4 of color images. Defaults to False. + gen_viewnormal_mp4 (bool, optional): Generate MP4 of view-space normals. Defaults to False. + gen_glonormal_mp4 (bool, optional): Generate MP4 of global-space normals. Defaults to False. + no_index_file (bool, optional): Skip saving index file. Defaults to False. + light_factor (float, optional): PBR light intensity multiplier. Defaults to 1.0. + + Example: + ```py + from embodied_gen.data.differentiable_render import ImageRender + from embodied_gen.data.utils import CameraSetting + from embodied_gen.utils.enum import RenderItems + + camera_params = CameraSetting( + num_images=6, + elevation=[20, -10], + distance=5, + resolution_hw=(512,512), + fov=math.radians(30), + device='cuda', + ) + render_items = [RenderItems.IMAGE.value, RenderItems.DEPTH.value] + renderer = ImageRender( + render_items, + camera_params, + with_mtl=args.with_mtl, + gen_color_mp4=True, + ) + renderer.render_mesh(mesh_path='mesh.obj', output_root='./renders') + ``` """ def __init__( @@ -198,6 +228,14 @@ def render_mesh( uuid: Union[str, List[str]] = None, prompts: List[str] = None, ) -> None: + """Renders one or more meshes and saves outputs. + + Args: + mesh_path (Union[str, List[str]]): Path(s) to mesh files. + output_root (str): Directory to save outputs. + uuid (Union[str, List[str]], optional): Unique IDs for outputs. + prompts (List[str], optional): Text prompts for videos. + """ mesh_path = as_list(mesh_path) if uuid is None: uuid = [os.path.basename(p).split(".")[0] for p in mesh_path] @@ -227,18 +265,15 @@ def render_mesh( def __call__( self, mesh_path: str, output_dir: str, prompt: str = None ) -> dict[str, str]: - """Render a single mesh and return paths to the rendered outputs. - - Processes the input mesh, renders multiple modalities (e.g., normals, - depth, albedo), and optionally saves video or image sequences. + """Renders a single mesh and returns output paths. Args: - mesh_path (str): Path to the mesh file (.obj/.glb). - output_dir (str): Directory to save rendered outputs. - prompt (str, optional): Optional caption prompt for MP4 metadata. + mesh_path (str): Path to mesh file. + output_dir (str): Directory to save outputs. + prompt (str, optional): Caption prompt for MP4 metadata. Returns: - dict[str, str]: A mapping render types to the saved image paths. + dict[str, str]: Mapping of render types to saved image paths. 
""" try: mesh = import_kaolin_mesh(mesh_path, self.with_mtl) diff --git a/embodied_gen/data/mesh_operator.py b/embodied_gen/data/mesh_operator.py index 4954900..893e459 100644 --- a/embodied_gen/data/mesh_operator.py +++ b/embodied_gen/data/mesh_operator.py @@ -16,17 +16,13 @@ import logging -import multiprocessing as mp -import os from typing import Tuple, Union -import coacd import igraph import numpy as np import pyvista as pv import spaces import torch -import trimesh import utils3d from pymeshfix import _meshfix from tqdm import tqdm diff --git a/embodied_gen/envs/pick_embodiedgen.py b/embodied_gen/envs/pick_embodiedgen.py index b654bcc..a44e5f1 100644 --- a/embodied_gen/envs/pick_embodiedgen.py +++ b/embodied_gen/envs/pick_embodiedgen.py @@ -51,6 +51,33 @@ @register_env("PickEmbodiedGen-v1", max_episode_steps=100) class PickEmbodiedGen(BaseEnv): + """PickEmbodiedGen as gym env example for object pick-and-place tasks. + + This environment simulates a robot interacting with 3D assets in the + embodiedgen generated scene in SAPIEN. It supports multi-environment setups, + dynamic reconfiguration, and hybrid rendering with 3D Gaussian Splatting. + + Example: + Use `gym.make` to create the `PickEmbodiedGen-v1` parallel environment. + ```python + import gymnasium as gym + env = gym.make( + "PickEmbodiedGen-v1", + num_envs=cfg.num_envs, + render_mode=cfg.render_mode, + enable_shadow=cfg.enable_shadow, + layout_file=cfg.layout_file, + control_mode=cfg.control_mode, + camera_cfg=dict( + camera_eye=cfg.camera_eye, + camera_target_pt=cfg.camera_target_pt, + image_hw=cfg.image_hw, + fovy_deg=cfg.fovy_deg, + ), + ) + ``` + """ + SUPPORTED_ROBOTS = ["panda", "panda_wristcam", "fetch"] goal_thresh = 0.0 @@ -63,6 +90,19 @@ def __init__( reconfiguration_freq: int = None, **kwargs, ): + """Initializes the PickEmbodiedGen environment. + + Args: + *args: Variable length argument list for the base class. + robot_uids: The robot(s) to use in the environment. + robot_init_qpos_noise: Noise added to the robot's initial joint + positions. + num_envs: The number of parallel environments to create. + reconfiguration_freq: How often to reconfigure the scene. If None, + it is set based on num_envs. + **kwargs: Additional keyword arguments for environment setup, + including layout_file, replace_objs, enable_grasp, etc. + """ self.robot_init_qpos_noise = robot_init_qpos_noise if reconfiguration_freq is None: if num_envs == 1: @@ -116,6 +156,22 @@ def __init__( def init_env_layouts( layout_file: str, num_envs: int, replace_objs: bool ) -> list[LayoutInfo]: + """Initializes and saves layout files for each environment instance. + + For each environment, this method creates a layout configuration. If + `replace_objs` is True, it generates new object placements for each + subsequent environment. The generated layouts are saved as new JSON + files. + + Args: + layout_file: Path to the base layout JSON file. + num_envs: The number of environments to create layouts for. + replace_objs: If True, generates new object placements for each + environment after the first one using BFS placement. + + Returns: + A list of file paths to the generated layout for each environment. + """ layouts = [] for env_idx in range(num_envs): if replace_objs and env_idx > 0: @@ -136,6 +192,18 @@ def init_env_layouts( def compute_robot_init_pose( layouts: list[str], num_envs: int, z_offset: float = 0.0 ) -> list[list[float]]: + """Computes the initial pose for the robot in each environment. 
+ + Args: + layouts: A list of file paths to the environment layouts. + num_envs: The number of environments. + z_offset: An optional vertical offset to apply to the robot's + position to prevent collisions. + + Returns: + A list of initial poses ([x, y, z, qw, qx, qy, qz]) for the robot + in each environment. + """ robot_pose = [] for env_idx in range(num_envs): layout = json.load(open(layouts[env_idx], "r")) @@ -148,6 +216,11 @@ def compute_robot_init_pose( @property def _default_sim_config(self): + """Returns the default simulation configuration. + + Returns: + The default simulation configuration object. + """ return SimConfig( scene_config=SceneConfig( solver_position_iterations=30, @@ -163,6 +236,11 @@ def _default_sim_config(self): @property def _default_sensor_configs(self): + """Returns the default sensor configurations for the agent. + + Returns: + A list containing the default camera configuration. + """ pose = sapien_utils.look_at(eye=[0.3, 0, 0.6], target=[-0.1, 0, 0.1]) return [ @@ -171,6 +249,11 @@ def _default_sensor_configs(self): @property def _default_human_render_camera_configs(self): + """Returns the default camera configuration for human-friendly rendering. + + Returns: + The default camera configuration for the renderer. + """ pose = sapien_utils.look_at( eye=self.camera_cfg["camera_eye"], target=self.camera_cfg["camera_target_pt"], @@ -187,10 +270,24 @@ def _default_human_render_camera_configs(self): ) def _load_agent(self, options: dict): + """Loads the agent (robot) and a ground plane into the scene. + + Args: + options: A dictionary of options for loading the agent. + """ self.ground = build_ground(self.scene) super()._load_agent(options, sapien.Pose(p=[-10, 0, 10])) def _load_scene(self, options: dict): + """Loads all assets, objects, and the goal site into the scene. + + This method iterates through the layouts for each environment, loads the + specified assets, and adds them to the simulation. It also creates a + kinematic sphere to represent the goal site. + + Args: + options: A dictionary of options for loading the scene. + """ all_objects = [] logger.info(f"Loading EmbodiedGen assets...") for env_idx in range(self.num_envs): @@ -222,6 +319,15 @@ def _load_scene(self, options: dict): self._hidden_objects.append(self.goal_site) def _initialize_episode(self, env_idx: torch.Tensor, options: dict): + """Initializes an episode for a given set of environments. + + This method sets the goal position, resets the robot's joint positions + with optional noise, and sets its root pose. + + Args: + env_idx: A tensor of environment indices to initialize. + options: A dictionary of options for initialization. + """ with torch.device(self.device): b = len(env_idx) goal_xyz = torch.zeros((b, 3)) @@ -256,6 +362,21 @@ def _initialize_episode(self, env_idx: torch.Tensor, options: dict): def render_gs3d_images( self, layouts: list[str], num_envs: int, init_quat: list[float] ) -> dict[str, np.ndarray]: + """Renders background images using a pre-trained Gaussian Splatting model. + + This method pre-renders the static background for each environment from + the perspective of all cameras to be used for hybrid rendering. + + Args: + layouts: A list of file paths to the environment layouts. + num_envs: The number of environments. + init_quat: An initial quaternion to orient the Gaussian Splatting + model. + + Returns: + A dictionary mapping a unique key (e.g., 'camera-env_idx') to the + rendered background image as a numpy array. 
+ """ sim_coord_align = ( torch.tensor(SIM_COORD_ALIGN).to(torch.float32).to(self.device) ) @@ -293,6 +414,15 @@ def render_gs3d_images( return bg_images def render(self): + """Renders the environment based on the configured render_mode. + + Raises: + RuntimeError: If `render_mode` is not set. + NotImplementedError: If the `render_mode` is not supported. + + Returns: + The rendered output, which varies depending on the render mode. + """ if self.render_mode is None: raise RuntimeError("render_mode is not set.") if self.render_mode == "human": @@ -315,6 +445,17 @@ def render(self): def render_rgb_array( self, camera_name: str = None, return_alpha: bool = False ): + """Renders an RGB image from the human-facing render camera. + + Args: + camera_name: The name of the camera to render from. If None, uses + all human render cameras. + return_alpha: Whether to include the alpha channel in the output. + + Returns: + A numpy array representing the rendered image(s). If multiple + cameras are used, the images are tiled. + """ for obj in self._hidden_objects: obj.show_visual() self.scene.update_render( @@ -335,6 +476,11 @@ def render_rgb_array( return tile_images(images) def render_sensors(self): + """Renders images from all on-board sensor cameras. + + Returns: + A tiled image of all sensor outputs as a numpy array. + """ images = [] sensor_images = self.get_sensor_images() for image in sensor_images.values(): @@ -343,6 +489,14 @@ def render_sensors(self): return tile_images(images) def hybrid_render(self): + """Renders a hybrid image by blending simulated foreground with a background. + + The foreground is rendered with an alpha channel and then blended with + the pre-rendered Gaussian Splatting background image. + + Returns: + A torch tensor of the final blended RGB images. + """ fg_images = self.render_rgb_array( return_alpha=True ) # (n_env, h, w, 3) @@ -362,6 +516,16 @@ def hybrid_render(self): return images[..., :3] def evaluate(self): + """Evaluates the current state of the environment. + + Checks for task success criteria such as whether the object is grasped, + placed at the goal, and if the robot is static. + + Returns: + A dictionary containing boolean tensors for various success + metrics, including 'is_grasped', 'is_obj_placed', and overall + 'success'. + """ obj_to_goal_pos = ( self.obj.pose.p ) # self.goal_site.pose.p - self.obj.pose.p @@ -381,10 +545,31 @@ def evaluate(self): ) def _get_obs_extra(self, info: dict): + """Gets extra information for the observation dictionary. + + Args: + info: A dictionary containing evaluation information. + + Returns: + An empty dictionary, as no extra observations are added. + """ return dict() def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict): + """Computes a dense reward for the current step. + + The reward is a composite of reaching, grasping, placing, and + maintaining a static final pose. + + Args: + obs: The current observation. + action: The action taken in the current step. + info: A dictionary containing evaluation information from `evaluate()`. + + Returns: + A tensor containing the dense reward for each environment. + """ tcp_to_obj_dist = torch.linalg.norm( self.obj.pose.p - self.agent.tcp.pose.p, axis=1 ) @@ -417,4 +602,14 @@ def compute_dense_reward(self, obs: any, action: torch.Tensor, info: dict): def compute_normalized_dense_reward( self, obs: any, action: torch.Tensor, info: dict ): + """Computes a dense reward normalized to be between 0 and 1. + + Args: + obs: The current observation. 
+ action: The action taken in the current step. + info: A dictionary containing evaluation information from `evaluate()`. + + Returns: + A tensor containing the normalized dense reward for each environment. + """ return self.compute_dense_reward(obs=obs, action=action, info=info) / 6 diff --git a/embodied_gen/models/delight_model.py b/embodied_gen/models/delight_model.py index 14abb4c..9be7bbb 100644 --- a/embodied_gen/models/delight_model.py +++ b/embodied_gen/models/delight_model.py @@ -40,7 +40,7 @@ class DelightingModel(object): """A model to remove the lighting in image space. This model is encapsulated based on the Hunyuan3D-Delight model - from https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0 # noqa + from `https://huggingface.co/tencent/Hunyuan3D-2/tree/main/hunyuan3d-delight-v2-0` # noqa Attributes: image_guide_scale (float): Weight of image guidance in diffusion process. diff --git a/embodied_gen/models/image_comm_model.py b/embodied_gen/models/image_comm_model.py index 7a8c30c..a04364d 100644 --- a/embodied_gen/models/image_comm_model.py +++ b/embodied_gen/models/image_comm_model.py @@ -38,26 +38,61 @@ class BasePipelineLoader(ABC): + """Abstract base class for loading Hugging Face image generation pipelines. + + Attributes: + device (str): Device to load the pipeline on. + + Methods: + load(): Loads and returns the pipeline. + """ + def __init__(self, device="cuda"): self.device = device @abstractmethod def load(self): + """Load and return the pipeline instance.""" pass class BasePipelineRunner(ABC): + """Abstract base class for running image generation pipelines. + + Attributes: + pipe: The loaded pipeline. + + Methods: + run(prompt, **kwargs): Runs the pipeline with a prompt. + """ + def __init__(self, pipe): self.pipe = pipe @abstractmethod def run(self, prompt: str, **kwargs) -> Image.Image: + """Run the pipeline with the given prompt. + + Args: + prompt (str): Text prompt for image generation. + **kwargs: Additional pipeline arguments. + + Returns: + Image.Image: Generated image(s). + """ pass # ===== SD3.5-medium ===== class SD35Loader(BasePipelineLoader): + """Loader for Stable Diffusion 3.5 medium pipeline.""" + def load(self): + """Load the Stable Diffusion 3.5 medium pipeline. + + Returns: + StableDiffusion3Pipeline: Loaded pipeline. + """ pipe = StableDiffusion3Pipeline.from_pretrained( "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.float16, @@ -70,12 +105,25 @@ def load(self): class SD35Runner(BasePipelineRunner): + """Runner for Stable Diffusion 3.5 medium pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Stable Diffusion 3.5 medium. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Cosmos2 ===== class CosmosLoader(BasePipelineLoader): + """Loader for Cosmos2 text-to-image pipeline.""" + def __init__( self, model_id="nvidia/Cosmos-Predict2-2B-Text2Image", @@ -87,6 +135,8 @@ def __init__( self.local_dir = local_dir def _patch(self): + """Patch model and processor for optimized loading.""" + def patch_model(cls): orig = cls.from_pretrained @@ -110,6 +160,11 @@ def new(*args, **kwargs): patch_processor(SiglipProcessor) def load(self): + """Load the Cosmos2 text-to-image pipeline. + + Returns: + Cosmos2TextToImagePipeline: Loaded pipeline. 
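+
+        Example:
+            A minimal sketch assuming the default constructor arguments;
+            model weights are fetched from Hugging Face on first use.
+            ```py
+            loader = CosmosLoader()
+            pipe = loader.load()
+            images = CosmosRunner(pipe).run(prompt="a red apple on a table")
+            ```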
+ """ self._patch() snapshot_download( repo_id=self.model_id, @@ -141,7 +196,19 @@ def load(self): class CosmosRunner(BasePipelineRunner): + """Runner for Cosmos2 text-to-image pipeline.""" + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + """Generate images using Cosmos2 pipeline. + + Args: + prompt (str): Text prompt. + negative_prompt (str, optional): Negative prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe( prompt=prompt, negative_prompt=negative_prompt, **kwargs ).images @@ -149,7 +216,14 @@ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: # ===== Kolors ===== class KolorsLoader(BasePipelineLoader): + """Loader for Kolors pipeline.""" + def load(self): + """Load the Kolors pipeline. + + Returns: + KolorsPipeline: Loaded pipeline. + """ pipe = KolorsPipeline.from_pretrained( "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, @@ -164,13 +238,31 @@ def load(self): class KolorsRunner(BasePipelineRunner): + """Runner for Kolors pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Kolors pipeline. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Flux ===== class FluxLoader(BasePipelineLoader): + """Loader for Flux pipeline.""" + def load(self): + """Load the Flux pipeline. + + Returns: + FluxPipeline: Loaded pipeline. + """ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' pipe = FluxPipeline.from_pretrained( "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16 @@ -182,20 +274,50 @@ def load(self): class FluxRunner(BasePipelineRunner): + """Runner for Flux pipeline.""" + def run(self, prompt: str, **kwargs) -> Image.Image: + """Generate images using Flux pipeline. + + Args: + prompt (str): Text prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe(prompt=prompt, **kwargs).images # ===== Chroma ===== class ChromaLoader(BasePipelineLoader): + """Loader for Chroma pipeline.""" + def load(self): + """Load the Chroma pipeline. + + Returns: + ChromaPipeline: Loaded pipeline. + """ return ChromaPipeline.from_pretrained( "lodestones/Chroma", torch_dtype=torch.bfloat16 ).to(self.device) class ChromaRunner(BasePipelineRunner): + """Runner for Chroma pipeline.""" + def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: + """Generate images using Chroma pipeline. + + Args: + prompt (str): Text prompt. + negative_prompt (str, optional): Negative prompt. + **kwargs: Additional arguments. + + Returns: + Image.Image: Generated image(s). + """ return self.pipe( prompt=prompt, negative_prompt=negative_prompt, **kwargs ).images @@ -211,6 +333,22 @@ def run(self, prompt: str, negative_prompt=None, **kwargs) -> Image.Image: def build_hf_image_pipeline(name: str, device="cuda") -> BasePipelineRunner: + """Build a Hugging Face image generation pipeline runner by name. + + Args: + name (str): Name of the pipeline (e.g., "sd35", "cosmos"). + device (str): Device to load the pipeline on. + + Returns: + BasePipelineRunner: Pipeline runner instance. 
+ + Example: + ```py + from embodied_gen.models.image_comm_model import build_hf_image_pipeline + runner = build_hf_image_pipeline("sd35") + images = runner.run(prompt="A robot holding a sign that says 'Hello'") + ``` + """ if name not in PIPELINE_REGISTRY: raise ValueError(f"Unsupported model: {name}") loader_cls, runner_cls = PIPELINE_REGISTRY[name] diff --git a/embodied_gen/models/layout.py b/embodied_gen/models/layout.py index 9613269..8edc279 100644 --- a/embodied_gen/models/layout.py +++ b/embodied_gen/models/layout.py @@ -376,6 +376,21 @@ class LayoutDesigner(object): + """A class for querying GPT-based scene layout reasoning and formatting responses. + + Attributes: + prompt (str): The system prompt for GPT. + verbose (bool): Whether to log responses. + gpt_client (GPTclient): The GPT client instance. + + Methods: + query(prompt, params): Query GPT with a prompt and parameters. + format_response(response): Parse and clean JSON response. + format_response_repair(response): Repair and parse JSON response. + save_output(output, save_path): Save output to file. + __call__(prompt, save_path, params): Query and process output. + """ + def __init__( self, gpt_client: GPTclient, @@ -387,6 +402,15 @@ def __init__( self.gpt_client = gpt_client def query(self, prompt: str, params: dict = None) -> str: + """Query GPT with the system prompt and user prompt. + + Args: + prompt (str): User prompt. + params (dict, optional): GPT parameters. + + Returns: + str: GPT response. + """ full_prompt = self.prompt + f"\n\nInput:\n\"{prompt}\"" response = self.gpt_client.query( @@ -400,6 +424,17 @@ def query(self, prompt: str, params: dict = None) -> str: return response def format_response(self, response: str) -> dict: + """Format and parse GPT response as JSON. + + Args: + response (str): Raw GPT response. + + Returns: + dict: Parsed JSON output. + + Raises: + json.JSONDecodeError: If parsing fails. + """ cleaned = re.sub(r"^```json\s*|\s*```$", "", response.strip()) try: output = json.loads(cleaned) @@ -411,9 +446,23 @@ def format_response(self, response: str) -> dict: return output def format_response_repair(self, response: str) -> dict: + """Repair and parse possibly broken JSON response. + + Args: + response (str): Raw GPT response. + + Returns: + dict: Parsed JSON output. + """ return json_repair.loads(response) def save_output(self, output: dict, save_path: str) -> None: + """Save output dictionary to a file. + + Args: + output (dict): Output data. + save_path (str): Path to save the file. + """ os.makedirs(os.path.dirname(save_path), exist_ok=True) with open(save_path, 'w') as f: json.dump(output, f, indent=4) @@ -421,6 +470,16 @@ def save_output(self, output: dict, save_path: str) -> None: def __call__( self, prompt: str, save_path: str = None, params: dict = None ) -> dict | str: + """Query GPT and process the output. + + Args: + prompt (str): User prompt. + save_path (str, optional): Path to save output. + params (dict, optional): GPT parameters. + + Returns: + dict | str: Output data. + """ response = self.query(prompt, params=params) output = self.format_response_repair(response) self.save_output(output, save_path) if save_path else None @@ -442,6 +501,29 @@ def __call__( def build_scene_layout( task_desc: str, output_path: str = None, gpt_params: dict = None ) -> LayoutInfo: + """Build a 3D scene layout from a natural language task description. 
+ + This function uses GPT-based reasoning to generate a structured scene layout, + including object hierarchy, spatial relations, and style descriptions. + + Args: + task_desc (str): Natural language description of the robotic task. + output_path (str, optional): Path to save the visualized scene tree. + gpt_params (dict, optional): Parameters for GPT queries. + + Returns: + LayoutInfo: Structured layout information for the scene. + + Example: + ```py + from embodied_gen.models.layout import build_scene_layout + layout_info = build_scene_layout( + task_desc="Put the apples on the table on the plate", + output_path="outputs/scene_tree.jpg", + ) + print(layout_info) + ``` + """ layout_relation = LAYOUT_DISASSEMBLER(task_desc, params=gpt_params) layout_tree = LAYOUT_GRAPHER(layout_relation, params=gpt_params) object_mapping = Scene3DItemEnum.object_mapping(layout_relation) diff --git a/embodied_gen/models/segment_model.py b/embodied_gen/models/segment_model.py index ab92c0c..6f54cfa 100644 --- a/embodied_gen/models/segment_model.py +++ b/embodied_gen/models/segment_model.py @@ -48,12 +48,19 @@ class SAMRemover(object): - """Loading SAM models and performing background removal on images. + """Loads SAM models and performs background removal on images. Attributes: checkpoint (str): Path to the model checkpoint. - model_type (str): Type of the SAM model to load (default: "vit_h"). - area_ratio (float): Area ratio filtering small connected components. + model_type (str): Type of the SAM model to load. + area_ratio (float): Area ratio for filtering small connected components. + + Example: + ```py + from embodied_gen.models.segment_model import SAMRemover + remover = SAMRemover(model_type="vit_h") + result = remover("input.jpg", "output.png") + ``` """ def __init__( @@ -78,6 +85,14 @@ def __init__( self.mask_generator = self._load_sam_model(checkpoint) def _load_sam_model(self, checkpoint: str) -> SamAutomaticMaskGenerator: + """Loads the SAM model and returns a mask generator. + + Args: + checkpoint (str): Path to model checkpoint. + + Returns: + SamAutomaticMaskGenerator: Mask generator instance. + """ sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) @@ -89,13 +104,11 @@ def __call__( """Removes the background from an image using the SAM model. Args: - image (Union[str, Image.Image, np.ndarray]): Input image, - can be a file path, PIL Image, or numpy array. - save_path (str): Path to save the output image (default: None). + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. Returns: - Image.Image: The image with background removed, - including an alpha channel. + Image.Image: Image with background removed (RGBA). """ # Convert input to numpy array if isinstance(image, str): @@ -134,6 +147,15 @@ def __call__( class SAMPredictor(object): + """Loads SAM models and predicts segmentation masks from user points. + + Args: + checkpoint (str, optional): Path to model checkpoint. + model_type (str, optional): SAM model type. + binary_thresh (float, optional): Threshold for binary mask. + device (str, optional): Device for inference. + """ + def __init__( self, checkpoint: str = None, @@ -157,12 +179,28 @@ def __init__( self.binary_thresh = binary_thresh def _load_sam_model(self, checkpoint: str) -> SamPredictor: + """Loads the SAM model and returns a predictor. + + Args: + checkpoint (str): Path to model checkpoint. + + Returns: + SamPredictor: Predictor instance. 
+ """ sam = sam_model_registry[self.model_type](checkpoint=checkpoint) sam.to(device=self.device) return SamPredictor(sam) def preprocess_image(self, image: Image.Image) -> np.ndarray: + """Preprocesses input image for SAM prediction. + + Args: + image (Image.Image): Input image. + + Returns: + np.ndarray: Preprocessed image array. + """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -178,6 +216,15 @@ def generate_masks( image: np.ndarray, selected_points: list[list[int]], ) -> np.ndarray: + """Generates segmentation masks from selected points. + + Args: + image (np.ndarray): Input image array. + selected_points (list[list[int]]): List of points and labels. + + Returns: + list[tuple[np.ndarray, str]]: List of masks and names. + """ if len(selected_points) == 0: return [] @@ -220,6 +267,15 @@ def generate_masks( def get_segmented_image( self, image: np.ndarray, masks: list[tuple[np.ndarray, str]] ) -> Image.Image: + """Combines masks and returns segmented image with alpha channel. + + Args: + image (np.ndarray): Input image array. + masks (list[tuple[np.ndarray, str]]): List of masks. + + Returns: + Image.Image: Segmented RGBA image. + """ seg_image = Image.fromarray(image, mode="RGB") alpha_channel = np.zeros( (seg_image.height, seg_image.width), dtype=np.uint8 @@ -241,6 +297,15 @@ def __call__( image: Union[str, Image.Image, np.ndarray], selected_points: list[list[int]], ) -> Image.Image: + """Segments image using selected points. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + selected_points (list[list[int]]): List of points and labels. + + Returns: + Image.Image: Segmented RGBA image. + """ image = self.preprocess_image(image) self.predictor.set_image(image) masks = self.generate_masks(image, selected_points) @@ -249,12 +314,32 @@ def __call__( class RembgRemover(object): + """Removes background from images using the rembg library. + + Example: + ```py + from embodied_gen.models.segment_model import RembgRemover + remover = RembgRemover() + result = remover("input.jpg", "output.png") + ``` + """ + def __init__(self): + """Initializes the RembgRemover.""" self.rembg_session = rembg.new_session("u2net") def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ) -> Image.Image: + """Removes background from an image. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: Image with background removed (RGBA). + """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -271,7 +356,18 @@ def __call__( class BMGG14Remover(object): + """Removes background using the RMBG-1.4 segmentation model. + + Example: + ```py + from embodied_gen.models.segment_model import BMGG14Remover + remover = BMGG14Remover() + result = remover("input.jpg", "output.png") + ``` + """ + def __init__(self) -> None: + """Initializes the BMGG14Remover.""" self.model = pipeline( "image-segmentation", model="briaai/RMBG-1.4", @@ -281,6 +377,15 @@ def __init__(self) -> None: def __call__( self, image: Union[str, Image.Image, np.ndarray], save_path: str = None ): + """Removes background from an image. + + Args: + image (Union[str, Image.Image, np.ndarray]): Input image. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: Image with background removed. 
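+
+        Example:
+            A minimal sketch with an in-memory input and no `save_path`.
+            ```py
+            from PIL import Image
+
+            remover = BMGG14Remover()
+            rgba = remover(Image.open("photo.jpg"))
+            ```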
+ """ if isinstance(image, str): image = Image.open(image) elif isinstance(image, np.ndarray): @@ -299,6 +404,16 @@ def __call__( def invert_rgba_pil( image: Image.Image, mask: Image.Image, save_path: str = None ) -> Image.Image: + """Inverts the alpha channel of an RGBA image using a mask. + + Args: + image (Image.Image): Input RGB image. + mask (Image.Image): Mask image for alpha inversion. + save_path (str, optional): Path to save the output image. + + Returns: + Image.Image: RGBA image with inverted alpha. + """ mask = (255 - np.array(mask))[..., None] image_array = np.concatenate([np.array(image), mask], axis=-1) inverted_image = Image.fromarray(image_array, "RGBA") @@ -318,6 +433,20 @@ def get_segmented_image_by_agent( save_path: str = None, mode: Literal["loose", "strict"] = "loose", ) -> Image.Image: + """Segments an image using SAM and rembg, with quality checking. + + Args: + image (Image.Image): Input image. + sam_remover (SAMRemover): SAM-based remover. + rbg_remover (RembgRemover): rembg-based remover. + seg_checker (ImageSegChecker, optional): Quality checker. + save_path (str, optional): Path to save the output image. + mode (Literal["loose", "strict"], optional): Segmentation mode. + + Returns: + Image.Image: Segmented RGBA image. + """ + def _is_valid_seg(raw_img: Image.Image, seg_img: Image.Image) -> bool: if seg_checker is None: return True diff --git a/embodied_gen/models/sr_model.py b/embodied_gen/models/sr_model.py index 40310bb..c8489d9 100644 --- a/embodied_gen/models/sr_model.py +++ b/embodied_gen/models/sr_model.py @@ -39,13 +39,38 @@ class ImageStableSR: - """Super-resolution image upscaler using Stable Diffusion x4 upscaling model from StabilityAI.""" + """Super-resolution image upscaler using Stable Diffusion x4 upscaling model. + + This class wraps the StabilityAI Stable Diffusion x4 upscaler for high-quality + image super-resolution. + + Args: + model_path (str, optional): Path or HuggingFace repo for the model. + device (str, optional): Device for inference. + + Example: + ```py + from embodied_gen.models.sr_model import ImageStableSR + from PIL import Image + + sr_model = ImageStableSR() + img = Image.open("input.png") + upscaled = sr_model(img) + upscaled.save("output.png") + ``` + """ def __init__( self, model_path: str = "stabilityai/stable-diffusion-x4-upscaler", device="cuda", ) -> None: + """Initializes the Stable Diffusion x4 upscaler. + + Args: + model_path (str, optional): Model path or repo. + device (str, optional): Device for inference. + """ from diffusers import StableDiffusionUpscalePipeline self.up_pipeline_x4 = StableDiffusionUpscalePipeline.from_pretrained( @@ -62,6 +87,16 @@ def __call__( prompt: str = "", infer_step: int = 20, ) -> Image.Image: + """Performs super-resolution on the input image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + prompt (str, optional): Text prompt for upscaling. + infer_step (int, optional): Number of inference steps. + + Returns: + Image.Image: Upscaled image. + """ if isinstance(image, np.ndarray): image = Image.fromarray(image) @@ -86,9 +121,26 @@ class ImageRealESRGAN: Attributes: outscale (int): The output image scale factor (e.g., 2, 4). model_path (str): Path to the pre-trained model weights. 
+ + Example: + ```py + from embodied_gen.models.sr_model import ImageRealESRGAN + from PIL import Image + + sr_model = ImageRealESRGAN(outscale=4) + img = Image.open("input.png") + upscaled = sr_model(img) + upscaled.save("output.png") + ``` """ def __init__(self, outscale: int, model_path: str = None) -> None: + """Initializes the RealESRGAN upscaler. + + Args: + outscale (int): Output scale factor. + model_path (str, optional): Path to model weights. + """ # monkey patch to support torchvision>=0.16 import torchvision from packaging import version @@ -122,6 +174,7 @@ def __init__(self, outscale: int, model_path: str = None) -> None: self.model_path = model_path def _lazy_init(self): + """Lazily initializes the RealESRGAN model.""" if self.upsampler is None: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan import RealESRGANer @@ -145,6 +198,14 @@ def _lazy_init(self): @spaces.GPU def __call__(self, image: Union[Image.Image, np.ndarray]) -> Image.Image: + """Performs super-resolution on the input image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + + Returns: + Image.Image: Upscaled image. + """ self._lazy_init() if isinstance(image, Image.Image): diff --git a/embodied_gen/models/text_model.py b/embodied_gen/models/text_model.py index 0807814..59167bb 100644 --- a/embodied_gen/models/text_model.py +++ b/embodied_gen/models/text_model.py @@ -60,6 +60,11 @@ def download_kolors_weights(local_dir: str = "weights/Kolors") -> None: + """Downloads Kolors model weights from HuggingFace. + + Args: + local_dir (str, optional): Local directory to store weights. + """ logger.info(f"Download kolors weights from huggingface...") os.makedirs(local_dir, exist_ok=True) subprocess.run( @@ -93,6 +98,22 @@ def build_text2img_ip_pipeline( ref_scale: float, device: str = "cuda", ) -> StableDiffusionXLPipelineIP: + """Builds a Stable Diffusion XL pipeline with IP-Adapter for text-to-image generation. + + Args: + ckpt_dir (str): Directory containing model checkpoints. + ref_scale (float): Reference scale for IP-Adapter. + device (str, optional): Device for inference. + + Returns: + StableDiffusionXLPipelineIP: Configured pipeline. + + Example: + ```py + from embodied_gen.models.text_model import build_text2img_ip_pipeline + pipe = build_text2img_ip_pipeline("weights/Kolors", ref_scale=0.3) + ``` + """ download_kolors_weights(ckpt_dir) text_encoder = ChatGLMModel.from_pretrained( @@ -146,6 +167,21 @@ def build_text2img_pipeline( ckpt_dir: str, device: str = "cuda", ) -> StableDiffusionXLPipeline: + """Builds a Stable Diffusion XL pipeline for text-to-image generation. + + Args: + ckpt_dir (str): Directory containing model checkpoints. + device (str, optional): Device for inference. + + Returns: + StableDiffusionXLPipeline: Configured pipeline. + + Example: + ```py + from embodied_gen.models.text_model import build_text2img_pipeline + pipe = build_text2img_pipeline("weights/Kolors") + ``` + """ download_kolors_weights(ckpt_dir) text_encoder = ChatGLMModel.from_pretrained( @@ -185,6 +221,29 @@ def text2img_gen( ip_image_size: int = 512, seed: int = None, ) -> list[Image.Image]: + """Generates images from text prompts using a Stable Diffusion XL pipeline. + + Args: + prompt (str): Text prompt for image generation. + n_sample (int): Number of images to generate. + guidance_scale (float): Guidance scale for diffusion. + pipeline (StableDiffusionXLPipeline | StableDiffusionXLPipelineIP): Pipeline instance. + ip_image (Image.Image | str, optional): Reference image for IP-Adapter. 
+ image_wh (tuple[int, int], optional): Output image size (width, height). + infer_step (int, optional): Number of inference steps. + ip_image_size (int, optional): Size for IP-Adapter image. + seed (int, optional): Random seed. + + Returns: + list[Image.Image]: List of generated images. + + Example: + ```py + from embodied_gen.models.text_model import text2img_gen + images = text2img_gen(prompt="banana", n_sample=3, guidance_scale=7.5) + images[0].save("banana.png") + ``` + """ prompt = PROMPT_KAPPEND.format(object=prompt.strip()) logger.info(f"Processing prompt: {prompt}") diff --git a/embodied_gen/trainer/pono2mesh_trainer.py b/embodied_gen/trainer/pono2mesh_trainer.py index a2fc752..6f04435 100644 --- a/embodied_gen/trainer/pono2mesh_trainer.py +++ b/embodied_gen/trainer/pono2mesh_trainer.py @@ -53,26 +53,31 @@ class Pano2MeshSRPipeline: - """Converting panoramic RGB image into 3D mesh representations, followed by inpainting and mesh refinement. + """Pipeline for converting panoramic RGB images into 3D mesh representations. - This class integrates several key components including: - - Depth estimation from RGB panorama - - Inpainting of missing regions under offsets - - RGB-D to mesh conversion - - Multi-view mesh repair - - 3D Gaussian Splatting (3DGS) dataset generation + This class integrates depth estimation, inpainting, mesh conversion, multi-view mesh repair, + and 3D Gaussian Splatting (3DGS) dataset generation. Args: config (Pano2MeshSRConfig): Configuration object containing model and pipeline parameters. Example: - ```python + ```py + from embodied_gen.trainer.pono2mesh_trainer import Pano2MeshSRPipeline + from embodied_gen.utils.config import Pano2MeshSRConfig + + config = Pano2MeshSRConfig() pipeline = Pano2MeshSRPipeline(config) pipeline(pano_image='example.png', output_dir='./output') ``` """ def __init__(self, config: Pano2MeshSRConfig) -> None: + """Initializes the pipeline with models and camera poses. + + Args: + config (Pano2MeshSRConfig): Configuration object. + """ self.cfg = config self.device = config.device @@ -93,6 +98,7 @@ def __init__(self, config: Pano2MeshSRConfig) -> None: self.kernel = torch.from_numpy(kernel).float().to(self.device) def init_mesh_params(self) -> None: + """Initializes mesh parameters and inpaint mask.""" torch.set_default_device(self.device) self.inpaint_mask = torch.ones( (self.cfg.cubemap_h, self.cfg.cubemap_w), dtype=torch.bool @@ -103,6 +109,14 @@ def init_mesh_params(self) -> None: @staticmethod def read_camera_pose_file(filepath: str) -> np.ndarray: + """Reads a camera pose file and returns the pose matrix. + + Args: + filepath (str): Path to the camera pose file. + + Returns: + np.ndarray: 4x4 camera pose matrix. + """ with open(filepath, "r") as f: values = [float(num) for line in f for num in line.split()] @@ -111,6 +125,14 @@ def read_camera_pose_file(filepath: str) -> np.ndarray: def load_camera_poses( self, trajectory_dir: str ) -> tuple[np.ndarray, list[torch.Tensor]]: + """Loads camera poses from a directory. + + Args: + trajectory_dir (str): Directory containing camera pose files. + + Returns: + tuple[np.ndarray, list[torch.Tensor]]: List of relative camera poses. + """ pose_filenames = sorted( [ fname @@ -148,6 +170,14 @@ def load_camera_poses( def load_inpaint_poses( self, poses: torch.Tensor ) -> dict[int, torch.Tensor]: + """Samples and loads poses for inpainting. + + Args: + poses (torch.Tensor): Tensor of camera poses. + + Returns: + dict[int, torch.Tensor]: Dictionary mapping indices to pose tensors. 
+ """ inpaint_poses = dict() sampled_views = poses[:: self.cfg.inpaint_frame_stride] init_pose = torch.eye(4) @@ -162,6 +192,14 @@ def load_inpaint_poses( return inpaint_poses def project(self, world_to_cam: torch.Tensor): + """Projects the mesh to an image using the given camera pose. + + Args: + world_to_cam (torch.Tensor): World-to-camera transformation matrix. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected RGB image, inpaint mask, and depth map. + """ ( project_image, project_depth, @@ -185,6 +223,14 @@ def project(self, world_to_cam: torch.Tensor): return project_image[:3, ...], inpaint_mask, project_depth def render_pano(self, pose: torch.Tensor): + """Renders a panorama from the mesh using the given pose. + + Args: + pose (torch.Tensor): Camera pose. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: RGB panorama, depth map, and mask. + """ cubemap_list = [] for cubemap_pose in self.cubemap_w2cs: project_pose = cubemap_pose @ pose @@ -213,6 +259,15 @@ def rgbd_to_mesh( world_to_cam: torch.Tensor = None, using_distance_map: bool = True, ) -> None: + """Converts RGB-D images to mesh and updates mesh parameters. + + Args: + rgb (torch.Tensor): RGB image tensor. + depth (torch.Tensor): Depth map tensor. + inpaint_mask (torch.Tensor): Inpaint mask tensor. + world_to_cam (torch.Tensor, optional): Camera pose. + using_distance_map (bool, optional): Whether to use distance map. + """ if world_to_cam is None: world_to_cam = torch.eye(4, dtype=torch.float32).to(self.device) @@ -239,6 +294,15 @@ def rgbd_to_mesh( def get_edge_image_by_depth( self, depth: torch.Tensor, dilate_iter: int = 1 ) -> np.ndarray: + """Computes edge image from depth map. + + Args: + depth (torch.Tensor): Depth map tensor. + dilate_iter (int, optional): Number of dilation iterations. + + Returns: + np.ndarray: Edge image. + """ if isinstance(depth, torch.Tensor): depth = depth.cpu().detach().numpy() @@ -253,6 +317,15 @@ def get_edge_image_by_depth( def mesh_repair_by_greedy_view_selection( self, pose_dict: dict[str, torch.Tensor], output_dir: str ) -> list: + """Repairs mesh by selecting views greedily and inpainting missing regions. + + Args: + pose_dict (dict[str, torch.Tensor]): Dictionary of poses for inpainting. + output_dir (str): Directory to save visualizations. + + Returns: + list: List of inpainted panoramas with poses. + """ inpainted_panos_w_pose = [] while len(pose_dict) > 0: logger.info(f"Repairing mesh left rounds {len(pose_dict)}") @@ -343,6 +416,17 @@ def inpaint_panorama( distances: torch.Tensor, pano_mask: torch.Tensor, ) -> tuple[torch.Tensor]: + """Inpaints missing regions in a panorama. + + Args: + idx (int): Index of the panorama. + colors (torch.Tensor): RGB image tensor. + distances (torch.Tensor): Distance map tensor. + pano_mask (torch.Tensor): Mask tensor. + + Returns: + tuple[torch.Tensor]: Inpainted RGB image, distances, and normals. + """ mask = (pano_mask[None, ..., None] > 0.5).float() mask = mask.permute(0, 3, 1, 2) mask = dilation(mask, kernel=self.kernel) @@ -364,6 +448,14 @@ def inpaint_panorama( def preprocess_pano( self, image: Image.Image | str ) -> tuple[torch.Tensor, torch.Tensor]: + """Preprocesses a panoramic image for mesh generation. + + Args: + image (Image.Image | str): Input image or path. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Preprocessed RGB and depth tensors. 
+ """ if isinstance(image, str): image = Image.open(image) @@ -387,6 +479,17 @@ def preprocess_pano( def pano_to_perpective( self, pano_image: torch.Tensor, pitch: float, yaw: float, fov: float ) -> torch.Tensor: + """Converts a panoramic image to a perspective view. + + Args: + pano_image (torch.Tensor): Panoramic image tensor. + pitch (float): Pitch angle. + yaw (float): Yaw angle. + fov (float): Field of view. + + Returns: + torch.Tensor: Perspective image tensor. + """ rots = dict( roll=0, pitch=pitch, @@ -404,6 +507,14 @@ def pano_to_perpective( return perspective def pano_to_cubemap(self, pano_rgb: torch.Tensor): + """Converts a panoramic RGB image to six cubemap views. + + Args: + pano_rgb (torch.Tensor): Panoramic RGB image tensor. + + Returns: + list: List of cubemap RGB tensors. + """ # Define six canonical cube directions in (pitch, yaw) directions = [ (0, 0), @@ -424,6 +535,11 @@ def pano_to_cubemap(self, pano_rgb: torch.Tensor): return cubemaps_rgb def save_mesh(self, output_path: str) -> None: + """Saves the mesh to a file. + + Args: + output_path (str): Path to save the mesh file. + """ vertices_np = self.vertices.T.cpu().numpy() colors_np = self.colors.T.cpu().numpy() faces_np = self.faces.T.cpu().numpy() @@ -434,6 +550,14 @@ def save_mesh(self, output_path: str) -> None: mesh.export(output_path) def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray: + """Converts mesh pose to 3D Gaussian Splatting pose. + + Args: + mesh_pose (torch.Tensor): Mesh pose tensor. + + Returns: + np.ndarray: Converted pose matrix. + """ pose = mesh_pose.clone() pose[0, :] *= -1 pose[1, :] *= -1 @@ -450,6 +574,15 @@ def mesh_pose_to_gs_pose(self, mesh_pose: torch.Tensor) -> np.ndarray: return c2w def __call__(self, pano_image: Image.Image | str, output_dir: str): + """Runs the pipeline to generate mesh and 3DGS data from a panoramic image. + + Args: + pano_image (Image.Image | str): Input panoramic image or path. + output_dir (str): Directory to save outputs. + + Returns: + None + """ self.init_mesh_params() pano_rgb, pano_depth = self.preprocess_pano(pano_image) self.sup_pool = SupInfoPool() diff --git a/embodied_gen/utils/enum.py b/embodied_gen/utils/enum.py index f807f81..807d4da 100644 --- a/embodied_gen/utils/enum.py +++ b/embodied_gen/utils/enum.py @@ -24,11 +24,27 @@ "Scene3DItemEnum", "SpatialRelationEnum", "RobotItemEnum", + "LayoutInfo", + "AssetType", + "SimAssetMapper", ] @dataclass class RenderItems(str, Enum): + """Enumeration of render item types for 3D scenes. + + Attributes: + IMAGE: Color image. + ALPHA: Mask image. + VIEW_NORMAL: View-space normal image. + GLOBAL_NORMAL: World-space normal image. + POSITION_MAP: Position map image. + DEPTH: Depth image. + ALBEDO: Albedo image. + DIFFUSE: Diffuse image. + """ + IMAGE = "image_color" ALPHA = "image_mask" VIEW_NORMAL = "image_view_normal" @@ -41,6 +57,21 @@ class RenderItems(str, Enum): @dataclass class Scene3DItemEnum(str, Enum): + """Enumeration of 3D scene item categories. + + Attributes: + BACKGROUND: Background objects. + CONTEXT: Contextual objects. + ROBOT: Robot entity. + MANIPULATED_OBJS: Objects manipulated by the robot. + DISTRACTOR_OBJS: Distractor objects. + OTHERS: Other objects. + + Methods: + object_list(layout_relation): Returns a list of objects in the scene. + object_mapping(layout_relation): Returns a mapping from object to category. 
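+
+    Example:
+        A small sketch showing that members are string-valued, so they can be
+        used directly as keys of a layout relation dictionary.
+        ```py
+        print(Scene3DItemEnum.BACKGROUND.value)  # "background"
+        print(Scene3DItemEnum.ROBOT.value)  # "robot"
+        ```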
+ """ + BACKGROUND = "background" CONTEXT = "context" ROBOT = "robot" @@ -50,6 +81,14 @@ class Scene3DItemEnum(str, Enum): @classmethod def object_list(cls, layout_relation: dict) -> list: + """Returns a list of objects in the scene. + + Args: + layout_relation: Dictionary mapping categories to objects. + + Returns: + List of objects in the scene. + """ return ( [ layout_relation[cls.BACKGROUND.value], @@ -61,6 +100,14 @@ def object_list(cls, layout_relation: dict) -> list: @classmethod def object_mapping(cls, layout_relation): + """Returns a mapping from object to category. + + Args: + layout_relation: Dictionary mapping categories to objects. + + Returns: + Dictionary mapping object names to their category. + """ relation_mapping = { # layout_relation[cls.ROBOT.value]: cls.ROBOT.value, layout_relation[cls.BACKGROUND.value]: cls.BACKGROUND.value, @@ -84,6 +131,15 @@ def object_mapping(cls, layout_relation): @dataclass class SpatialRelationEnum(str, Enum): + """Enumeration of spatial relations for objects in a scene. + + Attributes: + ON: Objects on a surface (e.g., table). + IN: Objects in a container or room. + INSIDE: Objects inside a shelf or rack. + FLOOR: Objects on the floor. + """ + ON = "ON" # objects on the table IN = "IN" # objects in the room INSIDE = "INSIDE" # objects inside the shelf/rack @@ -92,6 +148,14 @@ class SpatialRelationEnum(str, Enum): @dataclass class RobotItemEnum(str, Enum): + """Enumeration of supported robot types. + + Attributes: + FRANKA: Franka robot. + UR5: UR5 robot. + PIPER: Piper robot. + """ + FRANKA = "franka" UR5 = "ur5" PIPER = "piper" @@ -99,6 +163,18 @@ class RobotItemEnum(str, Enum): @dataclass class LayoutInfo(DataClassJsonMixin): + """Data structure for layout information in a 3D scene. + + Attributes: + tree: Hierarchical structure of scene objects. + relation: Spatial relations between objects. + objs_desc: Descriptions of objects. + objs_mapping: Mapping from object names to categories. + assets: Asset file paths for objects. + quality: Quality information for assets. + position: Position coordinates for objects. + """ + tree: dict[str, list] relation: dict[str, str | list[str]] objs_desc: dict[str, str] = field(default_factory=dict) @@ -106,3 +182,64 @@ class LayoutInfo(DataClassJsonMixin): assets: dict[str, str] = field(default_factory=dict) quality: dict[str, str] = field(default_factory=dict) position: dict[str, list[float]] = field(default_factory=dict) + + +@dataclass +class AssetType(str): + """Enumeration for asset types. + + Supported types: + MJCF: MuJoCo XML format. + USD: Universal Scene Description format. + URDF: Unified Robot Description Format. + MESH: Mesh file format. + """ + + MJCF = "mjcf" + USD = "usd" + URDF = "urdf" + MESH = "mesh" + + +class SimAssetMapper: + """Maps simulator names to asset types. + + Provides a mapping from simulator names to their corresponding asset type. + + Example: + ```py + from embodied_gen.utils.enum import SimAssetMapper + asset_type = SimAssetMapper["isaacsim"] + print(asset_type) # Output: 'usd' + ``` + + Methods: + __class_getitem__(key): Returns the asset type for a given simulator name. + """ + + _mapping = dict( + ISAACSIM=AssetType.USD, + ISAACGYM=AssetType.URDF, + MUJOCO=AssetType.MJCF, + GENESIS=AssetType.MJCF, + SAPIEN=AssetType.URDF, + PYBULLET=AssetType.URDF, + ) + + @classmethod + def __class_getitem__(cls, key: str): + """Returns the asset type for a given simulator name. + + Args: + key: Name of the simulator. + + Returns: + AssetType corresponding to the simulator. 
+ + Raises: + KeyError: If the simulator name is not recognized. + """ + key = key.upper() + if key.startswith("SAPIEN"): + key = "SAPIEN" + return cls._mapping[key] diff --git a/embodied_gen/utils/geometry.py b/embodied_gen/utils/geometry.py index 8352ccc..c5dbe85 100644 --- a/embodied_gen/utils/geometry.py +++ b/embodied_gen/utils/geometry.py @@ -45,13 +45,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]: - """Convert a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw). + """Converts a 4x4 transformation matrix to a pose (x, y, z, qx, qy, qz, qw). Args: matrix (np.ndarray): 4x4 transformation matrix. Returns: - List[float]: Pose as [x, y, z, qx, qy, qz, qw]. + list[float]: Pose as [x, y, z, qx, qy, qz, qw]. """ x, y, z = matrix[:3, 3] rot_mat = matrix[:3, :3] @@ -62,13 +62,13 @@ def matrix_to_pose(matrix: np.ndarray) -> list[float]: def pose_to_matrix(pose: list[float]) -> np.ndarray: - """Convert pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix. + """Converts pose (x, y, z, qx, qy, qz, qw) to a 4x4 transformation matrix. Args: - List[float]: Pose as [x, y, z, qx, qy, qz, qw]. + pose (list[float]): Pose as [x, y, z, qx, qy, qz, qw]. Returns: - matrix (np.ndarray): 4x4 transformation matrix. + np.ndarray: 4x4 transformation matrix. """ x, y, z, qx, qy, qz, qw = pose r = R.from_quat([qx, qy, qz, qw]) @@ -82,6 +82,16 @@ def pose_to_matrix(pose: list[float]) -> np.ndarray: def compute_xy_bbox( vertices: np.ndarray, col_x: int = 0, col_y: int = 1 ) -> list[float]: + """Computes the bounding box in XY plane for given vertices. + + Args: + vertices (np.ndarray): Vertex coordinates. + col_x (int, optional): Column index for X. + col_y (int, optional): Column index for Y. + + Returns: + list[float]: [min_x, max_x, min_y, max_y] + """ x_vals = vertices[:, col_x] y_vals = vertices[:, col_y] return x_vals.min(), x_vals.max(), y_vals.min(), y_vals.max() @@ -92,6 +102,16 @@ def has_iou_conflict( placed_boxes: list[list[float]], iou_threshold: float = 0.0, ) -> bool: + """Checks for intersection-over-union conflict between boxes. + + Args: + new_box (list[float]): New box coordinates. + placed_boxes (list[list[float]]): List of placed box coordinates. + iou_threshold (float, optional): IOU threshold. + + Returns: + bool: True if conflict exists, False otherwise. + """ new_min_x, new_max_x, new_min_y, new_max_y = new_box for min_x, max_x, min_y, max_y in placed_boxes: ix1 = max(new_min_x, min_x) @@ -105,7 +125,14 @@ def has_iou_conflict( def with_seed(seed_attr_name: str = "seed"): - """A parameterized decorator that temporarily sets the random seed.""" + """Decorator to temporarily set the random seed for reproducibility. + + Args: + seed_attr_name (str, optional): Name of the seed argument. + + Returns: + function: Decorator function. + """ def decorator(func): @wraps(func) @@ -143,6 +170,20 @@ def compute_convex_hull_path( y_axis: int = 1, z_axis: int = 2, ) -> Path: + """Computes a dense convex hull path for the top surface of a mesh. + + Args: + vertices (np.ndarray): Mesh vertices. + z_threshold (float, optional): Z threshold for top surface. + interp_per_edge (int, optional): Interpolation points per edge. + margin (float, optional): Margin for polygon buffer. + x_axis (int, optional): X axis index. + y_axis (int, optional): Y axis index. + z_axis (int, optional): Z axis index. + + Returns: + Path: Matplotlib path object for the convex hull. 
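+
+    Example:
+        A minimal sketch with random points, assuming the argument defaults
+        are sufficient for a quick check.
+        ```py
+        import numpy as np
+
+        vertices = np.random.rand(500, 3)  # (x, y, z) points in [0, 1]
+        hull = compute_convex_hull_path(vertices)
+        print(hull.contains_point((0.5, 0.5)))
+        ```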
+ """ top_vertices = vertices[ vertices[:, z_axis] > vertices[:, z_axis].max() - z_threshold ] @@ -170,6 +211,15 @@ def compute_convex_hull_path( def find_parent_node(node: str, tree: dict) -> str | None: + """Finds the parent node of a given node in a tree. + + Args: + node (str): Node name. + tree (dict): Tree structure. + + Returns: + str | None: Parent node name or None. + """ for parent, children in tree.items(): if any(child[0] == node for child in children): return parent @@ -177,6 +227,16 @@ def find_parent_node(node: str, tree: dict) -> str | None: def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool: + """Checks if at least `threshold` corners of a box are inside a hull. + + Args: + hull (Path): Convex hull path. + box (list): Box coordinates [x1, x2, y1, y2]. + threshold (int, optional): Minimum corners inside. + + Returns: + bool: True if enough corners are inside. + """ x1, x2, y1, y2 = box corners = [[x1, y1], [x2, y1], [x1, y2], [x2, y2]] @@ -187,6 +247,15 @@ def all_corners_inside(hull: Path, box: list, threshold: int = 3) -> bool: def compute_axis_rotation_quat( axis: Literal["x", "y", "z"], angle_rad: float ) -> list[float]: + """Computes quaternion for rotation around a given axis. + + Args: + axis (Literal["x", "y", "z"]): Axis of rotation. + angle_rad (float): Rotation angle in radians. + + Returns: + list[float]: Quaternion [x, y, z, w]. + """ if axis.lower() == "x": q = Quaternion(axis=[1, 0, 0], angle=angle_rad) elif axis.lower() == "y": @@ -202,6 +271,15 @@ def compute_axis_rotation_quat( def quaternion_multiply( init_quat: list[float], rotate_quat: list[float] ) -> list[float]: + """Multiplies two quaternions. + + Args: + init_quat (list[float]): Initial quaternion [x, y, z, w]. + rotate_quat (list[float]): Rotation quaternion [x, y, z, w]. + + Returns: + list[float]: Resulting quaternion [x, y, z, w]. + """ qx, qy, qz, qw = init_quat q1 = Quaternion(w=qw, x=qx, y=qy, z=qz) qx, qy, qz, qw = rotate_quat @@ -217,7 +295,17 @@ def check_reachable( min_reach: float = 0.25, max_reach: float = 0.85, ) -> bool: - """Check if the target point is within the reachable range.""" + """Checks if the target point is within the reachable range. + + Args: + base_xyz (np.ndarray): Base position. + reach_xyz (np.ndarray): Target position. + min_reach (float, optional): Minimum reach distance. + max_reach (float, optional): Maximum reach distance. + + Returns: + bool: True if reachable, False otherwise. + """ distance = np.linalg.norm(reach_xyz - base_xyz) return min_reach < distance < max_reach @@ -238,26 +326,31 @@ def bfs_placement( robot_dim: float = 0.12, seed: int = None, ) -> LayoutInfo: - """Place objects in the layout using BFS traversal. + """Places objects in a scene layout using BFS traversal. Args: - layout_file: Path to the JSON file defining the layout structure and assets. - floor_margin: Z-offset for the background object, typically for objects placed on the floor. - beside_margin: Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails. - max_attempts: Maximum number of attempts to find a non-overlapping position for an object. - init_rpy: Initial Roll-Pitch-Yaw rotation rad applied to all object meshes to align the mesh's - coordinate system with the world's (e.g., Z-up). - rotate_objs: If True, apply a random rotation around the Z-axis for manipulated and distractor objects. - rotate_bg: If True, apply a random rotation around the Y-axis for the background object. 
- rotate_context: If True, apply a random rotation around the Z-axis for the context object. - limit_reach_range: If set, enforce a check that manipulated objects are within the robot's reach range, in meter. - max_orient_diff: If set, enforce a check that manipulated objects are within the robot's orientation range, in degree. - robot_dim: The approximate dimension (e.g., diameter) of the robot for box representation. - seed: Random seed for reproducible placement. + layout_file (str): Path to layout JSON file generated from `layout-cli`. + floor_margin (float, optional): Z-offset for objects placed on the floor. + beside_margin (float, optional): Minimum margin for objects placed 'beside' their parent, used when 'on' placement fails. + max_attempts (int, optional): Max attempts for a non-overlapping placement. + init_rpy (tuple, optional): Initial rotation (rpy). + rotate_objs (bool, optional): Whether to random rotate objects. + rotate_bg (bool, optional): Whether to random rotate background. + rotate_context (bool, optional): Whether to random rotate context asset. + limit_reach_range (tuple[float, float] | None, optional): If set, enforce a check that manipulated objects are within the robot's reach range, in meter. + max_orient_diff (float | None, optional): If set, enforce a check that manipulated objects are within the robot's orientation range, in degree. + robot_dim (float, optional): The approximate robot size. + seed (int, optional): Random seed for reproducible placement. Returns: - A :class:`LayoutInfo` object containing the objects and their final computed 7D poses - ([x, y, z, qx, qy, qz, qw]). + LayoutInfo: Layout information with object poses. + + Example: + ```py + from embodied_gen.utils.geometry import bfs_placement + layout = bfs_placement("scene_layout.json", seed=42) + print(layout.position) + ``` """ layout_info = LayoutInfo.from_dict(json.load(open(layout_file, "r"))) asset_dir = os.path.dirname(layout_file) @@ -478,6 +571,13 @@ def bfs_placement( def compose_mesh_scene( layout_info: LayoutInfo, out_scene_path: str, with_bg: bool = False ) -> None: + """Composes a mesh scene from layout information and saves to file. + + Args: + layout_info (LayoutInfo): Layout information. + out_scene_path (str): Output scene file path. + with_bg (bool, optional): Include background mesh. + """ object_mapping = Scene3DItemEnum.object_mapping(layout_info.relation) scene = trimesh.Scene() for node in layout_info.assets: @@ -505,6 +605,16 @@ def compose_mesh_scene( def compute_pinhole_intrinsics( image_w: int, image_h: int, fov_deg: float ) -> np.ndarray: + """Computes pinhole camera intrinsic matrix from image size and FOV. + + Args: + image_w (int): Image width. + image_h (int): Image height. + fov_deg (float): Field of view in degrees. + + Returns: + np.ndarray: Intrinsic matrix K. + """ fov_rad = np.deg2rad(fov_deg) fx = image_w / (2 * np.tan(fov_rad / 2)) fy = fx # assuming square pixels diff --git a/embodied_gen/utils/gpt_clients.py b/embodied_gen/utils/gpt_clients.py index de435e2..47f5ce2 100644 --- a/embodied_gen/utils/gpt_clients.py +++ b/embodied_gen/utils/gpt_clients.py @@ -45,7 +45,35 @@ class GPTclient: - """A client to interact with the GPT model via OpenAI or Azure API.""" + """A client to interact with GPT models via OpenAI or Azure API. + + Supports text and image prompts, connection checking, and configurable parameters. + + Args: + endpoint (str): API endpoint URL. + api_key (str): API key for authentication. + model_name (str, optional): Model name to use. 
+ api_version (str, optional): API version (for Azure). + check_connection (bool, optional): Whether to check API connection. + verbose (bool, optional): Enable verbose logging. + + Example: + ```sh + export ENDPOINT="https://yfb-openai-sweden.openai.azure.com" + export API_KEY="xxxxxx" + export API_VERSION="2025-03-01-preview" + export MODEL_NAME="yfb-gpt-4o-sweden" + ``` + ```py + from embodied_gen.utils.gpt_clients import GPT_CLIENT + + response = GPT_CLIENT.query("Describe the physics of a falling apple.") + response = GPT_CLIENT.query( + text_prompt="Describe the content in each image." + image_base64=["path/to/image1.png", "path/to/image2.jpg"], + ) + ``` + """ def __init__( self, @@ -82,6 +110,7 @@ def __init__( stop=(stop_after_attempt(10) | stop_after_delay(30)), ) def completion_with_backoff(self, **kwargs): + """Performs a chat completion request with retry/backoff.""" return self.client.chat.completions.create(**kwargs) def query( @@ -91,19 +120,16 @@ def query( system_role: Optional[str] = None, params: Optional[dict] = None, ) -> Optional[str]: - """Queries the GPT model with a text and optional image prompts. + """Queries the GPT model with text and optional image prompts. Args: - text_prompt (str): The main text input that the model responds to. - image_base64 (Optional[List[str]]): A list of image base64 strings - or local image paths or PIL.Image to accompany the text prompt. - system_role (Optional[str]): Optional system-level instructions - that specify the behavior of the assistant. - params (Optional[dict]): Additional parameters for GPT setting. + text_prompt (str): Main text input. + image_base64 (Optional[list[str | Image.Image]], optional): List of image base64 strings, file paths, or PIL Images. + system_role (Optional[str], optional): System-level instructions. + params (Optional[dict], optional): Additional GPT parameters. Returns: - Optional[str]: The response content generated by the model based on - the prompt. Returns `None` if an error occurs. + Optional[str]: Model response content, or None if error. """ if system_role is None: system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties." # noqa @@ -177,7 +203,11 @@ def query( return response def check_connection(self) -> None: - """Check whether the GPT API connection is working.""" + """Checks whether the GPT API connection is working. + + Raises: + ConnectionError: If connection fails. + """ try: response = self.completion_with_backoff( messages=[ diff --git a/embodied_gen/utils/process_media.py b/embodied_gen/utils/process_media.py index 88eb8e5..8feb7ec 100644 --- a/embodied_gen/utils/process_media.py +++ b/embodied_gen/utils/process_media.py @@ -69,6 +69,40 @@ def render_asset3d( no_index_file: bool = False, with_mtl: bool = True, ) -> list[str]: + """Renders a 3D mesh asset and returns output image paths. + + Args: + mesh_path (str): Path to the mesh file. + output_root (str): Directory to save outputs. + distance (float, optional): Camera distance. + num_images (int, optional): Number of views to render. + elevation (list[float], optional): Camera elevation angles. + pbr_light_factor (float, optional): PBR lighting factor. + return_key (str, optional): Glob pattern for output images. + output_subdir (str, optional): Subdirectory for outputs. + gen_color_mp4 (bool, optional): Generate color MP4 video. + gen_viewnormal_mp4 (bool, optional): Generate view normal MP4. + gen_glonormal_mp4 (bool, optional): Generate global normal MP4. 
+ no_index_file (bool, optional): Skip index file saving. + with_mtl (bool, optional): Use mesh material. + + Returns: + list[str]: List of output image file paths. + + Example: + ```py + from embodied_gen.utils.process_media import render_asset3d + + image_paths = render_asset3d( + mesh_path="path_to_mesh.obj", + output_root="path_to_save_dir", + num_images=6, + elevation=(30, -30), + output_subdir="renders", + no_index_file=True, + ) + ``` + """ input_args = dict( mesh_path=mesh_path, output_root=output_root, @@ -95,6 +129,13 @@ def render_asset3d( def merge_images_video(color_images, normal_images, output_path) -> None: + """Merges color and normal images into a video. + + Args: + color_images (list[np.ndarray]): List of color images. + normal_images (list[np.ndarray]): List of normal images. + output_path (str): Path to save the output video. + """ width = color_images[0].shape[1] combined_video = [ np.hstack([rgb_img[:, : width // 2], normal_img[:, width // 2 :]]) @@ -108,7 +149,13 @@ def merge_images_video(color_images, normal_images, output_path) -> None: def merge_video_video( video_path1: str, video_path2: str, output_path: str ) -> None: - """Merge two videos by the left half and the right half of the videos.""" + """Merges two videos by combining their left and right halves. + + Args: + video_path1 (str): Path to first video. + video_path2 (str): Path to second video. + output_path (str): Path to save the merged video. + """ clip1 = VideoFileClip(video_path1) clip2 = VideoFileClip(video_path2) @@ -127,6 +174,16 @@ def filter_small_connected_components( area_ratio: float, connectivity: int = 8, ) -> np.ndarray: + """Removes small connected components from a binary mask. + + Args: + mask (Union[Image.Image, np.ndarray]): Input mask. + area_ratio (float): Minimum area ratio for components. + connectivity (int, optional): Connectivity for labeling. + + Returns: + np.ndarray: Mask with small components removed. + """ if isinstance(mask, Image.Image): mask = np.array(mask) num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( @@ -152,6 +209,16 @@ def filter_image_small_connected_components( area_ratio: float = 10, connectivity: int = 8, ) -> np.ndarray: + """Removes small connected components from the alpha channel of an image. + + Args: + image (Union[Image.Image, np.ndarray]): Input image. + area_ratio (float, optional): Minimum area ratio. + connectivity (int, optional): Connectivity for labeling. + + Returns: + np.ndarray: Image with filtered alpha channel. + """ if isinstance(image, Image.Image): image = image.convert("RGBA") image = np.array(image) @@ -169,6 +236,24 @@ def combine_images_to_grid( target_wh: tuple[int, int] = (512, 512), image_mode: str = "RGB", ) -> list[Image.Image]: + """Combines multiple images into a grid. + + Args: + images (list[str | Image.Image]): List of image paths or PIL Images. + cat_row_col (tuple[int, int], optional): Grid rows and columns. + target_wh (tuple[int, int], optional): Target image size. + image_mode (str, optional): Image mode. + + Returns: + list[Image.Image]: List containing the grid image. + + Example: + ```py + from embodied_gen.utils.process_media import combine_images_to_grid + grid = combine_images_to_grid(["img1.png", "img2.png"]) + grid[0].save("grid.png") + ``` + """ n_images = len(images) if n_images == 1: return images @@ -196,6 +281,19 @@ def combine_images_to_grid( class SceneTreeVisualizer: + """Visualizes a scene tree layout using networkx and matplotlib. 
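For readers unfamiliar with `cv2.connectedComponentsWithStats`, used by `filter_small_connected_components` above, the sketch below removes small blobs from a toy mask. The 1% area threshold is illustrative only and does not claim to match the library's `area_ratio` convention.

```py
import cv2
import numpy as np

# Toy binary mask: one large blob plus a tiny speck that should be removed.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:40, 10:40] = 255
mask[60:62, 60:62] = 255

num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
min_area = mask.size * 0.01  # illustrative threshold: drop blobs below 1% of the image
cleaned = np.zeros_like(mask)
for label in range(1, num_labels):  # label 0 is the background
    if stats[label, cv2.CC_STAT_AREA] >= min_area:
        cleaned[labels == label] = 255
```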
+ + Args: + layout_info (LayoutInfo): Layout information for the scene. + + Example: + ```py + from embodied_gen.utils.process_media import SceneTreeVisualizer + visualizer = SceneTreeVisualizer(layout_info) + visualizer.render(save_path="tree.png") + ``` + """ + def __init__(self, layout_info: LayoutInfo) -> None: self.tree = layout_info.tree self.relation = layout_info.relation @@ -274,6 +372,14 @@ def render( dpi=300, title: str = "Scene 3D Hierarchy Tree", ): + """Renders the scene tree and saves to file. + + Args: + save_path (str): Path to save the rendered image. + figsize (tuple, optional): Figure size. + dpi (int, optional): Image DPI. + title (str, optional): Plot image title. + """ node_colors = [ self.role_colors[self._get_node_role(n)] for n in self.G.nodes ] @@ -350,6 +456,14 @@ def render( def load_scene_dict(file_path: str) -> dict: + """Loads a scene description dictionary from a file. + + Args: + file_path (str): Path to the scene description file. + + Returns: + dict: Mapping from scene ID to description. + """ scene_dict = {} with open(file_path, "r", encoding='utf-8') as f: for line in f: @@ -363,12 +477,28 @@ def load_scene_dict(file_path: str) -> dict: def is_image_file(filename: str) -> bool: + """Checks if a filename is an image file. + + Args: + filename (str): Filename to check. + + Returns: + bool: True if image file, False otherwise. + """ mime_type, _ = mimetypes.guess_type(filename) return mime_type is not None and mime_type.startswith('image') def parse_text_prompts(prompts: list[str]) -> list[str]: + """Parses text prompts from a list or file. + + Args: + prompts (list[str]): List of prompts or a file path. + + Returns: + list[str]: List of parsed prompts. + """ if len(prompts) == 1 and prompts[0].endswith(".txt"): with open(prompts[0], "r") as f: prompts = [ @@ -386,13 +516,18 @@ def alpha_blend_rgba( """Alpha blends a foreground RGBA image over a background RGBA image. Args: - fg_image: Foreground image. Can be a file path (str), a PIL Image, - or a NumPy ndarray. - bg_image: Background image. Can be a file path (str), a PIL Image, - or a NumPy ndarray. + fg_image: Foreground image (str, PIL Image, or ndarray). + bg_image: Background image (str, PIL Image, or ndarray). Returns: - A PIL Image representing the alpha-blended result in RGBA mode. + Image.Image: Alpha-blended RGBA image. + + Example: + ```py + from embodied_gen.utils.process_media import alpha_blend_rgba + result = alpha_blend_rgba("fg.png", "bg.png") + result.save("blended.png") + ``` """ if isinstance(fg_image, str): fg_image = Image.open(fg_image) @@ -421,13 +556,11 @@ def check_object_edge_truncated( """Checks if a binary object mask is truncated at the image edges. Args: - mask: A 2D binary NumPy array where nonzero values indicate the object region. - edge_threshold: Number of pixels from each image edge to consider for truncation. - Defaults to 5. + mask (np.ndarray): 2D binary mask. + edge_threshold (int, optional): Edge pixel threshold. Returns: - True if the object is fully enclosed (not truncated). - False if the object touches or crosses any image boundary. + bool: True if object is fully enclosed, False if truncated. """ top = mask[:edge_threshold, :].any() bottom = mask[-edge_threshold:, :].any() @@ -440,6 +573,22 @@ def check_object_edge_truncated( def vcat_pil_images( images: list[Image.Image], image_mode: str = "RGB" ) -> Image.Image: + """Vertically concatenates a list of PIL images. + + Args: + images (list[Image.Image]): List of images. 
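To make the behavior of `check_object_edge_truncated` above concrete, here is a standalone version of the border test. The top and bottom checks mirror the lines shown in the hunk; the left and right checks are inferred symmetrically, since they are not visible in the excerpt.

```py
import numpy as np

def is_fully_enclosed(mask: np.ndarray, edge_threshold: int = 5) -> bool:
    # True when no foreground pixel falls inside a band of `edge_threshold`
    # pixels along any image border.
    top = mask[:edge_threshold, :].any()
    bottom = mask[-edge_threshold:, :].any()
    left = mask[:, :edge_threshold].any()
    right = mask[:, -edge_threshold:].any()
    return not (top or bottom or left or right)

mask = np.zeros((100, 100), dtype=np.uint8)
mask[40:60, 40:60] = 1
print(is_fully_enclosed(mask))  # True: the blob stays clear of every border
```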
+ image_mode (str, optional): Image mode. + + Returns: + Image.Image: Vertically concatenated image. + + Example: + ```py + from embodied_gen.utils.process_media import vcat_pil_images + img = vcat_pil_images([Image.open("a.png"), Image.open("b.png")]) + img.save("vcat.png") + ``` + """ widths, heights = zip(*(img.size for img in images)) total_height = sum(heights) max_width = max(widths) diff --git a/embodied_gen/utils/simulation.py b/embodied_gen/utils/simulation.py index 6925cfb..5ff13b6 100644 --- a/embodied_gen/utils/simulation.py +++ b/embodied_gen/utils/simulation.py @@ -69,6 +69,21 @@ def load_actor_from_urdf( update_mass: bool = False, scale: float | np.ndarray = 1.0, ) -> sapien.pysapien.Entity: + """Load an sapien actor from a URDF file and add it to the scene. + + Args: + scene (sapien.Scene | ManiSkillScene): The simulation scene. + file_path (str): Path to the URDF file. + pose (sapien.Pose | None): Initial pose of the actor. + env_idx (int): Environment index for multi-env setup. + use_static (bool): Whether the actor is static. + update_mass (bool): Whether to update the actor's mass from URDF. + scale (float | np.ndarray): Scale factor for the actor. + + Returns: + sapien.pysapien.Entity: The created actor entity. + """ + def _get_local_pose(origin_tag: ET.Element | None) -> sapien.Pose: local_pose = sapien.Pose(p=[0, 0, 0], q=[1, 0, 0, 0]) if origin_tag is not None: @@ -154,14 +169,17 @@ def load_assets_from_layout_file( init_quat: list[float] = [0, 0, 0, 1], env_idx: int = None, ) -> dict[str, sapien.pysapien.Entity]: - """Load assets from `EmbodiedGen` layout-gen output and create actors in the scene. + """Load assets from an EmbodiedGen layout file and create sapien actors in the scene. Args: - scene (sapien.Scene | ManiSkillScene): The SAPIEN or ManiSkill scene to load assets into. - layout (str): The layout file path. - z_offset (float): Offset to apply to the Z-coordinate of non-context objects. - init_quat (List[float]): Initial quaternion (x, y, z, w) for orientation adjustment. - env_idx (int): Environment index for multi-environment setup. + scene (ManiSkillScene | sapien.Scene): The sapien simulation scene. + layout (str): Path to the embodiedgen layout file. + z_offset (float): Z offset for non-context objects. + init_quat (list[float]): Initial quaternion for orientation. + env_idx (int): Environment index. + + Returns: + dict[str, sapien.pysapien.Entity]: Mapping from object names to actor entities. """ asset_root = os.path.dirname(layout) layout = LayoutInfo.from_dict(json.load(open(layout, "r"))) @@ -206,6 +224,19 @@ def load_mani_skill_robot( control_mode: str = "pd_joint_pos", backend_str: tuple[str, str] = ("cpu", "gpu"), ) -> BaseAgent: + """Load a ManiSkill robot agent into the scene. + + Args: + scene (sapien.Scene | ManiSkillScene): The simulation scene. + layout (LayoutInfo | str): Layout info or path to layout file. + control_freq (int): Control frequency. + robot_init_qpos_noise (float): Noise for initial joint positions. + control_mode (str): Robot control mode. + backend_str (tuple[str, str]): Simulation/render backend. + + Returns: + BaseAgent: The loaded robot agent. + """ from mani_skill.agents import REGISTERED_AGENTS from mani_skill.envs.scene import ManiSkillScene from mani_skill.envs.utils.system.backend import ( @@ -278,14 +309,14 @@ def render_images( ] ] = None, ) -> dict[str, Image.Image]: - """Render images from a given sapien camera. + """Render images from a given SAPIEN camera. 
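The following is a simplified re-implementation of the vertical concatenation documented for `vcat_pil_images`, for illustration only; alignment and padding details of the library version are not visible in this hunk.

```py
from PIL import Image

def vcat(images: list[Image.Image], image_mode: str = "RGB") -> Image.Image:
    # Stack images top-to-bottom on a canvas wide enough for the widest input.
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new(image_mode, (max(widths), sum(heights)))
    y = 0
    for img in images:
        canvas.paste(img.convert(image_mode), (0, y))
        y += img.size[1]
    return canvas

vcat([Image.new("RGB", (64, 32), "red"), Image.new("RGB", (48, 40), "blue")]).save("vcat_demo.png")
```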
Args: - camera (sapien.render.RenderCameraComponent): The camera to render from. - render_keys (List[str]): Types of images to render (e.g., Color, Segmentation). + camera (sapien.render.RenderCameraComponent): Camera to render from. + render_keys (list[str], optional): Types of images to render. Returns: - Dict[str, Image.Image]: Dictionary of rendered images. + dict[str, Image.Image]: Dictionary of rendered images. """ if render_keys is None: render_keys = [ @@ -341,11 +372,33 @@ def render_images( class SapienSceneManager: - """A class to manage SAPIEN simulator.""" + """Manages SAPIEN simulation scenes, cameras, and rendering. + + This class provides utilities for setting up scenes, adding cameras, + stepping simulation, and rendering images. + + Attributes: + sim_freq (int): Simulation frequency. + ray_tracing (bool): Whether to use ray tracing. + device (str): Device for simulation. + renderer (sapien.SapienRenderer): SAPIEN renderer. + scene (sapien.Scene): Simulation scene. + cameras (list): List of camera components. + actors (dict): Mapping of actor names to entities. + + Example see `embodied_gen/scripts/simulate_sapien.py`. + """ def __init__( self, sim_freq: int, ray_tracing: bool, device: str = "cuda" ) -> None: + """Initialize the scene manager. + + Args: + sim_freq (int): Simulation frequency. + ray_tracing (bool): Enable ray tracing. + device (str): Device for simulation. + """ self.sim_freq = sim_freq self.ray_tracing = ray_tracing self.device = device @@ -355,7 +408,11 @@ def __init__( self.actors: dict[str, sapien.pysapien.Entity] = {} def _setup_scene(self) -> sapien.Scene: - """Set up the SAPIEN scene with lighting and ground.""" + """Set up the SAPIEN scene with lighting and ground. + + Returns: + sapien.Scene: The initialized scene. + """ # Ray tracing settings if self.ray_tracing: sapien.render.set_camera_shader_dir("rt") @@ -397,6 +454,18 @@ def step_action( render_keys: list[str], sim_steps_per_control: int = 1, ) -> dict: + """Step the simulation and render images from cameras. + + Args: + agent (BaseAgent): The robot agent. + action (torch.Tensor): Action to apply. + cameras (list): List of camera components. + render_keys (list[str]): Types of images to render. + sim_steps_per_control (int): Simulation steps per control. + + Returns: + dict: Dictionary of rendered frames per camera. + """ agent.set_action(action) frames = defaultdict(list) for _ in range(sim_steps_per_control): @@ -417,13 +486,13 @@ def create_camera( image_hw: tuple[int, int], fovy_deg: float, ) -> sapien.render.RenderCameraComponent: - """Create a single camera in the scene. + """Create a camera in the scene. Args: - cam_name (str): Name of the camera. - pose (sapien.Pose): Camera pose p=(x, y, z), q=(w, x, y, z) - image_hw (Tuple[int, int]): Image resolution (height, width) for cameras. - fovy_deg (float): Field of view in degrees for cameras. + cam_name (str): Camera name. + pose (sapien.Pose): Camera pose. + image_hw (tuple[int, int]): Image resolution (height, width). + fovy_deg (float): Field of view in degrees. Returns: sapien.render.RenderCameraComponent: The created camera. @@ -456,15 +525,15 @@ def initialize_circular_cameras( """Initialize multiple cameras arranged in a circle. Args: - num_cameras (int): Number of cameras to create. - radius (float): Radius of the camera circle. - height (float): Fixed Z-coordinate of the cameras. - target_pt (list[float]): 3D point (x, y, z) that cameras look at. - image_hw (Tuple[int, int]): Image resolution (height, width) for cameras. 
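To clarify the camera ring built by `initialize_circular_cameras`, the geometry-only sketch below spaces camera centers evenly on a circle at a fixed height and points each one at a shared target, matching the `2 * pi / num_cameras` spacing and Z-up convention shown above. The full look-at rotation and SAPIEN pose construction are intentionally omitted.

```py
import numpy as np

def circular_camera_rig(num_cameras: int, radius: float, height: float, target_pt) -> list:
    # Each entry is (camera position, unit forward direction toward the target).
    target = np.asarray(target_pt, dtype=float)
    angle_step = 2 * np.pi / num_cameras
    rig = []
    for i in range(num_cameras):
        angle = i * angle_step
        position = np.array([radius * np.cos(angle), radius * np.sin(angle), height])
        forward = target - position
        forward /= np.linalg.norm(forward)
        rig.append((position, forward))
    return rig

for pos, fwd in circular_camera_rig(4, radius=2.0, height=1.5, target_pt=[0, 0, 0.5]):
    print(pos.round(3), fwd.round(3))
```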
- fovy_deg (float): Field of view in degrees for cameras. + num_cameras (int): Number of cameras. + radius (float): Circle radius. + height (float): Camera height. + target_pt (list[float]): Target point to look at. + image_hw (tuple[int, int]): Image resolution. + fovy_deg (float): Field of view in degrees. Returns: - List[sapien.render.RenderCameraComponent]: List of created cameras. + list[sapien.render.RenderCameraComponent]: List of cameras. """ angle_step = 2 * np.pi / num_cameras world_up_vec = np.array([0.0, 0.0, 1.0]) @@ -510,6 +579,19 @@ def initialize_circular_cameras( class FrankaPandaGrasper(object): + """Provides grasp planning and control for Franka Panda robot. + + Attributes: + agent (BaseAgent): The robot agent. + robot: The robot instance. + control_freq (float): Control frequency. + control_timestep (float): Control timestep. + joint_vel_limits (float): Joint velocity limits. + joint_acc_limits (float): Joint acceleration limits. + finger_length (float): Length of gripper fingers. + planners: Motion planners for each environment. + """ + def __init__( self, agent: BaseAgent, @@ -518,6 +600,7 @@ def __init__( joint_acc_limits: float = 1.0, finger_length: float = 0.025, ) -> None: + """Initialize the grasper.""" self.agent = agent self.robot = agent.robot self.control_freq = control_freq @@ -553,6 +636,15 @@ def control_gripper( gripper_state: Literal[-1, 1], n_step: int = 10, ) -> np.ndarray: + """Generate gripper control actions. + + Args: + gripper_state (Literal[-1, 1]): Desired gripper state. + n_step (int): Number of steps. + + Returns: + np.ndarray: Array of gripper actions. + """ qpos = self.robot.get_qpos()[0, :-2].cpu().numpy() actions = [] for _ in range(n_step): @@ -571,6 +663,20 @@ def move_to_pose( action_key: str = "position", env_idx: int = 0, ) -> np.ndarray: + """Plan and execute motion to a target pose. + + Args: + pose (sapien.Pose): Target pose. + control_timestep (float): Control timestep. + gripper_state (Literal[-1, 1]): Desired gripper state. + use_point_cloud (bool): Use point cloud for planning. + n_max_step (int): Max number of steps. + action_key (str): Key for action in result. + env_idx (int): Environment index. + + Returns: + np.ndarray: Array of actions to reach the pose. + """ result = self.planners[env_idx].plan_qpos_to_pose( np.concatenate([pose.p, pose.q]), self.robot.get_qpos().cpu().numpy()[0], @@ -608,6 +714,17 @@ def compute_grasp_action( offset: tuple[float, float, float] = [0, 0, -0.05], env_idx: int = 0, ) -> np.ndarray: + """Compute grasp actions for a target actor. + + Args: + actor (sapien.pysapien.Entity): Target actor to grasp. + reach_target_only (bool): Only reach the target pose if True. + offset (tuple[float, float, float]): Offset for reach pose. + env_idx (int): Environment index. + + Returns: + np.ndarray: Array of grasp actions. + """ physx_rigid = actor.components[1] mesh = get_component_mesh(physx_rigid, to_world_frame=True) obb = mesh.bounding_box_oriented diff --git a/embodied_gen/utils/tags.py b/embodied_gen/utils/tags.py index 9c50269..9302331 100644 --- a/embodied_gen/utils/tags.py +++ b/embodied_gen/utils/tags.py @@ -1 +1 @@ -VERSION = "v0.1.5" +VERSION = "v0.1.6" diff --git a/embodied_gen/validators/aesthetic_predictor.py b/embodied_gen/validators/aesthetic_predictor.py index 921f363..6e77449 100644 --- a/embodied_gen/validators/aesthetic_predictor.py +++ b/embodied_gen/validators/aesthetic_predictor.py @@ -27,14 +27,22 @@ class AestheticPredictor: - """Aesthetic Score Predictor. 
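To illustrate the shape of the action sequence produced by `control_gripper` above, here is a self-contained sketch that holds the arm joints fixed and repeats a gripper command for a fixed number of steps. The exact action layout used by the library is not visible in this hunk, so the 7-DoF arm plus single gripper entry is an assumption made for illustration.

```py
import numpy as np

def repeat_gripper_action(arm_qpos: np.ndarray, gripper_state: int, n_step: int = 10) -> np.ndarray:
    # Assumed layout: arm joint targets followed by one gripper command per step.
    return np.stack(
        [np.concatenate([arm_qpos, [gripper_state]]) for _ in range(n_step)]
    )

arm_qpos = np.zeros(7)  # illustrative 7-DoF Franka arm configuration
actions = repeat_gripper_action(arm_qpos, gripper_state=1)
print(actions.shape)  # (10, 8)
```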
+ """Aesthetic Score Predictor using CLIP and a pre-trained MLP. - Checkpoints from https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main + Checkpoints from `https://github.com/christophschuhmann/improved-aesthetic-predictor/tree/main`. Args: - clip_model_dir (str): Path to the directory of the CLIP model. - sac_model_path (str): Path to the pre-trained SAC model. - device (str): Device to use for computation ("cuda" or "cpu"). + clip_model_dir (str, optional): Path to CLIP model directory. + sac_model_path (str, optional): Path to SAC model weights. + device (str, optional): Device for computation ("cuda" or "cpu"). + + Example: + ```py + from embodied_gen.validators.aesthetic_predictor import AestheticPredictor + predictor = AestheticPredictor(device="cuda") + score = predictor.predict("image.png") + print("Aesthetic score:", score) + ``` """ def __init__(self, clip_model_dir=None, sac_model_path=None, device="cpu"): @@ -109,7 +117,7 @@ def _load_sac_model(self, model_path, input_size): return model def predict(self, image_path): - """Predict the aesthetic score for a given image. + """Predicts the aesthetic score for a given image. Args: image_path (str): Path to the image file. diff --git a/embodied_gen/validators/quality_checkers.py b/embodied_gen/validators/quality_checkers.py index 65e236c..0e5ff7e 100644 --- a/embodied_gen/validators/quality_checkers.py +++ b/embodied_gen/validators/quality_checkers.py @@ -40,6 +40,16 @@ class BaseChecker: + """Base class for quality checkers using GPT clients. + + Provides a common interface for querying and validating responses. + Subclasses must implement the `query` method. + + Attributes: + prompt (str): The prompt used for queries. + verbose (bool): Whether to enable verbose logging. + """ + def __init__(self, prompt: str = None, verbose: bool = False) -> None: self.prompt = prompt self.verbose = verbose @@ -70,6 +80,15 @@ def __call__(self, *args, **kwargs) -> tuple[bool, str]: def validate( checkers: list["BaseChecker"], images_list: list[list[str]] ) -> list: + """Validates a list of checkers against corresponding image lists. + + Args: + checkers (list[BaseChecker]): List of checker instances. + images_list (list[list[str]]): List of image path lists. + + Returns: + list: Validation results with overall outcome. + """ assert len(checkers) == len(images_list) results = [] overall_result = True @@ -192,7 +211,7 @@ def query(self, image_paths: list[str]) -> str: class ImageAestheticChecker(BaseChecker): - """A class for evaluating the aesthetic quality of images. + """Evaluates the aesthetic quality of images using a CLIP-based predictor. Attributes: clip_model_dir (str): Path to the CLIP model directory. @@ -200,6 +219,14 @@ class ImageAestheticChecker(BaseChecker): thresh (float): Threshold above which images are considered aesthetically acceptable. verbose (bool): Whether to print detailed log messages. predictor (AestheticPredictor): The model used to predict aesthetic scores. + + Example: + ```py + from embodied_gen.validators.quality_checkers import ImageAestheticChecker + checker = ImageAestheticChecker(thresh=4.5) + flag, score = checker(["image1.png", "image2.png"]) + print("Aesthetic OK:", flag, "Score:", score) + ``` """ def __init__( @@ -227,6 +254,16 @@ def __call__(self, image_paths: list[str], **kwargs) -> bool: class SemanticConsistChecker(BaseChecker): + """Checks semantic consistency between text descriptions and segmented images. 
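A brief usage sketch for `SemanticConsistChecker.query`, following the signature documented below (`text` plus a list of image paths or PIL Images). It assumes the module-level `GPT_CLIENT` is configured and that the remaining constructor arguments keep their defaults; the image path is a placeholder.

```py
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.validators.quality_checkers import SemanticConsistChecker

checker = SemanticConsistChecker(GPT_CLIENT)
response = checker.query(
    text="a red ceramic mug",
    image=["segmented_mug.png"],  # hypothetical segmented-image path
)
print(response)
```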
+ + Uses GPT to evaluate if the image matches the text in object type, geometry, and color. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for consistency evaluation. + verbose (bool): Whether to enable verbose logging. + """ + def __init__( self, gpt_client: GPTclient, @@ -276,6 +313,16 @@ def query(self, text: str, image: list[Image.Image | str]) -> str: class TextGenAlignChecker(BaseChecker): + """Evaluates alignment between text prompts and generated 3D asset images. + + Assesses if the rendered images match the text description in category and geometry. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for alignment evaluation. + verbose (bool): Whether to enable verbose logging. + """ + def __init__( self, gpt_client: GPTclient, @@ -489,6 +536,17 @@ def __call__(self, image_paths: str | Image.Image) -> float: class SemanticMatcher(BaseChecker): + """Matches query text to semantically similar scene descriptions. + + Uses GPT to find the most similar scene IDs from a dictionary. + + Attributes: + gpt_client (GPTclient): GPT client for queries. + prompt (str): Prompt for semantic matching. + verbose (bool): Whether to enable verbose logging. + seed (int): Random seed for selection. + """ + def __init__( self, gpt_client: GPTclient, @@ -543,6 +601,17 @@ def __init__( def query( self, text: str, context: dict, rand: bool = True, params: dict = None ) -> str: + """Queries for semantically similar scene IDs. + + Args: + text (str): Query text. + context (dict): Dictionary of scene descriptions. + rand (bool, optional): Whether to randomly select from top matches. + params (dict, optional): Additional GPT parameters. + + Returns: + str: Matched scene ID. + """ match_list = self.gpt_client.query( self.prompt.format(context=context, text=text), params=params, diff --git a/embodied_gen/validators/urdf_convertor.py b/embodied_gen/validators/urdf_convertor.py index 7341ed7..3f070be 100644 --- a/embodied_gen/validators/urdf_convertor.py +++ b/embodied_gen/validators/urdf_convertor.py @@ -80,6 +80,31 @@ class URDFGenerator(object): + """Generates URDF files for 3D assets with physical and semantic attributes. + + Uses GPT to estimate object properties and generates a URDF file with mesh, friction, mass, and metadata. + + Args: + gpt_client (GPTclient): GPT client for attribute estimation. + mesh_file_list (list[str], optional): Additional mesh files to copy. + prompt_template (str, optional): Prompt template for GPT queries. + attrs_name (list[str], optional): List of attribute names to include. + render_dir (str, optional): Directory for rendered images. + render_view_num (int, optional): Number of views to render. + decompose_convex (bool, optional): Whether to decompose mesh for collision. + rotate_xyzw (list[float], optional): Quaternion for mesh rotation. + + Example: + ```py + from embodied_gen.validators.urdf_convertor import URDFGenerator + from embodied_gen.utils.gpt_clients import GPT_CLIENT + + urdf_gen = URDFGenerator(GPT_CLIENT, render_view_num=4) + urdf_path = urdf_gen(mesh_path="mesh.obj", output_root="output_dir") + print("Generated URDF:", urdf_path) + ``` + """ + def __init__( self, gpt_client: GPTclient, @@ -168,6 +193,14 @@ def __init__( self.rotate_xyzw = rotate_xyzw def parse_response(self, response: str) -> dict[str, any]: + """Parses GPT response to extract asset attributes. + + Args: + response (str): GPT response string. + + Returns: + dict[str, any]: Parsed attributes. 
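The sketch below pairs `SemanticMatcher.query` with `load_scene_dict` (documented earlier in this patch) to pick the scene ID whose description best matches a free-form query. The descriptions file name is a placeholder, and it is assumed that the matcher's remaining constructor arguments keep their defaults.

```py
from embodied_gen.utils.gpt_clients import GPT_CLIENT
from embodied_gen.utils.process_media import load_scene_dict
from embodied_gen.validators.quality_checkers import SemanticMatcher

scene_dict = load_scene_dict("scene_descriptions.txt")  # hypothetical descriptions file
matcher = SemanticMatcher(GPT_CLIENT)
scene_id = matcher.query("a cozy kitchen with a wooden table", context=scene_dict)
print(scene_id)
```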
+ """ lines = response.split("\n") lines = [line.strip() for line in lines if line] category = lines[0].split(": ")[1] @@ -207,11 +240,9 @@ def generate_urdf( Args: input_mesh (str): Path to the input mesh file. - output_dir (str): Directory to store the generated URDF - and processed mesh. - attr_dict (dict): Dictionary containing attributes like height, - mass, and friction coefficients. - output_name (str, optional): Name for the generated URDF and robot. + output_dir (str): Directory to store the generated URDF and mesh. + attr_dict (dict): Dictionary of asset attributes. + output_name (str, optional): Name for the URDF and robot. Returns: str: Path to the generated URDF file. @@ -336,6 +367,16 @@ def get_attr_from_urdf( attr_root: str = ".//link/extra_info", attr_name: str = "scale", ) -> float: + """Extracts an attribute value from a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + attr_root (str, optional): XML path to attribute root. + attr_name (str, optional): Attribute name. + + Returns: + float: Attribute value, or None if not found. + """ if not os.path.exists(urdf_path): raise FileNotFoundError(f"URDF file not found: {urdf_path}") @@ -358,6 +399,13 @@ def get_attr_from_urdf( def add_quality_tag( urdf_path: str, results: list, output_path: str = None ) -> None: + """Adds a quality tag to a URDF file. + + Args: + urdf_path (str): Path to the URDF file. + results (list): List of [checker_name, result] pairs. + output_path (str, optional): Output file path. + """ if output_path is None: output_path = urdf_path @@ -382,6 +430,14 @@ def add_quality_tag( logger.info(f"URDF files saved to {output_path}") def get_estimated_attributes(self, asset_attrs: dict): + """Calculates estimated attributes from asset properties. + + Args: + asset_attrs (dict): Asset attributes. + + Returns: + dict: Estimated attributes (height, mass, mu, category). + """ estimated_attrs = { "height": round( (asset_attrs["min_height"] + asset_attrs["max_height"]) / 2, 4 @@ -403,6 +459,18 @@ def __call__( category: str = "unknown", **kwargs, ): + """Generates a URDF file for a mesh asset. + + Args: + mesh_path (str): Path to mesh file. + output_root (str): Directory for outputs. + text_prompt (str, optional): Prompt for GPT. + category (str, optional): Asset category. + **kwargs: Additional attributes. + + Returns: + str: Path to generated URDF file. 
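To show what `get_attr_from_urdf` is documented to do, here is a standalone ElementTree sketch that reads a named value from the `.//link/extra_info` element of a generated URDF. Whether the value is stored as an XML attribute or as element text is not visible in this hunk, so the attribute form is an assumption, and the file path is illustrative.

```py
import xml.etree.ElementTree as ET

def read_urdf_attr(urdf_path: str, attr_root: str = ".//link/extra_info", attr_name: str = "scale"):
    # Parse the URDF and look up one attribute under the documented XML path.
    node = ET.parse(urdf_path).getroot().find(attr_root)
    if node is None or attr_name not in node.attrib:
        return None
    return float(node.attrib[attr_name])

# scale = read_urdf_attr("output_dir/result.urdf")  # hypothetical output path
```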
+ """ if text_prompt is None or len(text_prompt) == 0: text_prompt = self.prompt_template text_prompt = text_prompt.format(category=category.lower()) diff --git a/pyproject.toml b/pyproject.toml index 4f8d645..be3cc60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["embodied_gen"] [project] name = "embodied_gen" -version = "v0.1.5" +version = "v0.1.6" readme = "README.md" license = "Apache-2.0" license-files = ["LICENSE", "NOTICE"] diff --git a/tests/test_examples/test_asset_converter.py b/tests/test_examples/test_asset_converter.py index eeefdbb..6094945 100644 --- a/tests/test_examples/test_asset_converter.py +++ b/tests/test_examples/test_asset_converter.py @@ -4,10 +4,9 @@ from huggingface_hub import snapshot_download from embodied_gen.data.asset_converter import ( AssetConverterFactory, - AssetType, - SimAssetMapper, cvt_embodiedgen_asset_to_anysim, ) +from embodied_gen.utils.enum import AssetType, SimAssetMapper @pytest.fixture(scope="session") @@ -77,7 +76,10 @@ def test_cvt_embodiedgen_asset_to_anysim( ): dst_asset_path = cvt_embodiedgen_asset_to_anysim( urdf_files=[ - "outputs/embodiedgen_assets/demo_assets/remote_control2/result/remote_control.urdf", + "outputs/embodiedgen_assets/demo_assets/remote_control/result/remote_control.urdf", + ], + target_dirs=[ + "outputs/embodiedgen_assets/demo_assets/remote_control/usd/remote_control.usd", ], target_type=SimAssetMapper[simulator_name], source_type=AssetType.MESH,