From a408d59de98aacd765de56a74243c129025928fa Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 07:46:15 +0000
Subject: [PATCH 1/8] add cell complex dataset

---
 configs/dataset/cell/3D2M.yaml | 33 ++++
 configs/model/cell/cwn2.yaml | 38 ++++
 configs/transforms/dataset_defaults/3D2M.yaml | 3 +
 pyproject.toml | 1 +
 test/pipeline/test_pipeline.py | 4 +-
 topobench/data/datasets/d_3d2m_dataset.py | 170 ++++++++++++++++++
 topobench/data/loaders/__init__.py | 3 +
 topobench/data/loaders/cell/__init__.py | 99 ++++++++++
 topobench/data/loaders/cell/d_3d2m_loader.py | 43 +++++
 topobench/data/utils/__init__.py | 8 +-
 topobench/data/utils/io_utils.py | 166 +++++++++++++++--
 topobench/data/utils/utils.py | 44 +++++
 12 files changed, 597 insertions(+), 15 deletions(-)
 create mode 100644 configs/dataset/cell/3D2M.yaml
 create mode 100644 configs/model/cell/cwn2.yaml
 create mode 100644 configs/transforms/dataset_defaults/3D2M.yaml
 create mode 100644 topobench/data/datasets/d_3d2m_dataset.py
 create mode 100644 topobench/data/loaders/cell/__init__.py
 create mode 100644 topobench/data/loaders/cell/d_3d2m_loader.py

diff --git a/configs/dataset/cell/3D2M.yaml b/configs/dataset/cell/3D2M.yaml
new file mode 100644
index 000000000..0d9a4c515
--- /dev/null
+++ b/configs/dataset/cell/3D2M.yaml
@@ -0,0 +1,33 @@
+# Dataset loader config
+loader:
+  _target_: topobench.data.loaders.D3D2MDatasetLoader
+  parameters:
+    data_domain: cell
+    data_type: topological
+    data_name: 3D2M
+    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
+
+# Dataset parameters
+parameters:
+  # Dataset parameters
+  num_features: [3,3,3]
+  num_classes: 2
+  task: classification
+  loss_type: cross_entropy
+  monitor_metric: accuracy
+  task_level: graph
+
+# Splits
+split_params:
+  learning_setting: inductive
+  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
+  data_seed: 0
+  split_type: random # either "k-fold" or "random" strategies
+  k: 10 # for "k-fold" Cross-Validation
+  train_prop: 0.5 # for "random" strategy splitting
+
+# Dataloader parameters
+dataloader_params:
+  batch_size: 8 # Fixed
+  num_workers: 1
+  pin_memory: False
\ No newline at end of file
diff --git a/configs/model/cell/cwn2.yaml b/configs/model/cell/cwn2.yaml
new file mode 100644
index 000000000..44b7ebcc9
--- /dev/null
+++ b/configs/model/cell/cwn2.yaml
@@ -0,0 +1,38 @@
+_target_: topobench.model.TBModel
+
+model_name: cwn
+model_domain: cell
+
+feature_encoder:
+  _target_: topobench.nn.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
+  in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}}
+  out_channels: 3
+  proj_dropout: 0.0
+
+backbone:
+  _target_: topomodelx.nn.cell.cwn.CWN
+  in_channels_0: ${model.feature_encoder.out_channels}
+  in_channels_1: ${model.feature_encoder.out_channels}
+  in_channels_2: ${model.feature_encoder.out_channels}
+  hid_channels: ${model.feature_encoder.out_channels}
+  n_layers: 4
+
+backbone_wrapper:
+  _target_: topobench.nn.wrappers.CWNWrapper
+  _partial_: true
+  wrapper_name: CWNWrapper
+  out_channels: ${model.feature_encoder.out_channels}
+  num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}}
+
+readout:
+  _target_: topobench.nn.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use NoReadOut in case a readout is not needed. Options: PropagateSignalDown, NoReadOut
+  num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider
+  hidden_dim: ${model.feature_encoder.out_channels}
+  out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
+  pooling_type: sum
+
+# compile model for faster training with pytorch 2.0
+compile: false
diff --git a/configs/transforms/dataset_defaults/3D2M.yaml b/configs/transforms/dataset_defaults/3D2M.yaml
new file mode 100644
index 000000000..ce39aa53d
--- /dev/null
+++ b/configs/transforms/dataset_defaults/3D2M.yaml
@@ -0,0 +1,3 @@
+# 3D2M dataset needs identity transform to avoid adding random float feature to feature matrix
+defaults:
+  - data_manipulations: identity
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 3234ea9e6..456a536a3 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,7 @@ dependencies=[
     "topomodelx @ git+https://github.com/pyt-team/TopoModelX.git",
     "toponetx @ git+https://github.com/pyt-team/TopoNetX.git",
     "lightning==2.4.0",
+    "gdown",
 ]
 
 [project.optional-dependencies]
diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py
index 785987159..eeaadd4ec 100644
--- a/test/pipeline/test_pipeline.py
+++ b/test/pipeline/test_pipeline.py
@@ -4,8 +4,8 @@
 
 from test._utils.simplified_pipeline import run
 
-DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
-MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
+DATASET = "cell/3D2M" # ADD YOUR DATASET HERE
+MODELS = ["cell/cwn2"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
 
 
 class TestPipeline:
diff --git a/topobench/data/datasets/d_3d2m_dataset.py b/topobench/data/datasets/d_3d2m_dataset.py
new file mode 100644
index 000000000..7bb9a6058
--- /dev/null
+++ b/topobench/data/datasets/d_3d2m_dataset.py
@@ -0,0 +1,170 @@
+"""Dataset class for 3D2M dataset."""
+
+import os
+import os.path as osp
+import shutil
+from contextlib import suppress
+from pathlib import Path
+from typing import ClassVar
+
+from omegaconf import DictConfig
+from torch_geometric.data import Data, InMemoryDataset, extract_zip
+from torch_geometric.io import fs
+
+from topobench.data.utils import (
+    download_file_from_drive,
+    read_3d2m_meshes,
+)
+
+
+class D3D2MDataset(InMemoryDataset):
+    r"""Dataset class for 3D2M dataset.
+
+    Parameters
+    ----------
+    root : str
+        Root directory where the dataset will be saved.
+    name : str
+        Name of the dataset.
+    parameters : DictConfig
+        Configuration parameters for the dataset.
+
+    Attributes
+    ----------
+    URLS (dict): Dictionary containing the URLs for downloading the dataset.
+    FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
+    RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset.
+    """
+
+    URLS: ClassVar = {
+        "3D2M": "https://drive.google.com/file/d/1jxVSmjDQmojh_5LHPLb9RmSIenWPiYnj/view?usp=drive_link",
+    }
+
+    FILE_FORMAT: ClassVar = {
+        "3D2M": "zip",
+    }
+
+    RAW_FILE_NAMES: ClassVar = {}
+
+    def __init__(
+        self,
+        root: str,
+        name: str,
+        parameters: DictConfig,
+    ) -> None:
+        self.name = name
+        self.parameters = parameters
+        super().__init__(
+            root,
+        )
+
+        out = fs.torch_load(self.processed_paths[0])
+        assert len(out) == 3 or len(out) == 4
+
+        if len(out) == 3:  # Backward compatibility.
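+            # Older processed files stored (data, slices, sizes) without the
+            # Data class; newer ones also save the class as a fourth element.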
+            data, self.slices, self.sizes = out
+            data_cls = Data
+        else:
+            data, self.slices, self.sizes, data_cls = out
+
+        if not isinstance(data, dict):  # Backward compatibility.
+            self.data = data
+        else:
+            self.data = data_cls.from_dict(data)
+
+        assert isinstance(self._data, Data)
+
+    def __repr__(self) -> str:
+        return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.force_reload={self.force_reload})"
+
+    @property
+    def raw_dir(self) -> str:
+        """Return the path to the raw directory of the dataset.
+
+        Returns
+        -------
+        str
+            Path to the raw directory.
+        """
+        return osp.join(self.root, self.name, "raw")
+
+    @property
+    def processed_dir(self) -> str:
+        """Return the path to the processed directory of the dataset.
+
+        Returns
+        -------
+        str
+            Path to the processed directory.
+        """
+        return osp.join(self.root, self.name, "processed")
+
+    @property
+    def raw_file_names(self) -> list[str]:
+        """Return the raw file names for the dataset.
+
+        Returns
+        -------
+        list[str]
+            List of raw file names.
+        """
+        return ["*.obj", "*.npy"]
+
+    @property
+    def processed_file_names(self) -> str:
+        """Return the processed file name for the dataset.
+
+        Returns
+        -------
+        str
+            Processed file name.
+        """
+        return "data.pt"
+
+    def download(self) -> None:
+        r"""Download the dataset from its URL and save it to the raw directory.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the dataset URL is not found.
+        """
+        # Download data from the source
+        self.url = self.URLS[self.name]
+        self.file_format = self.FILE_FORMAT[self.name]
+        download_file_from_drive(
+            file_link=self.url,
+            path_to_save=self.raw_dir,
+            dataset_name=self.name,
+            file_format=self.file_format,
+        )
+
+        # Extract zip file
+        folder = self.raw_dir
+        filename = f"{self.name}.{self.file_format}"
+        path = osp.join(folder, filename)
+        extract_zip(path, folder)
+        # Delete zip file
+        os.unlink(path)
+
+        # Move the extracted subfolders from osp.join(folder, self.name) up into folder
+        data_folder_path = Path(osp.join(folder, self.name))
+        for subfolder in data_folder_path.iterdir():
+            with suppress(OSError):
+                subfolder.rename(Path(folder) / subfolder.name)
+        shutil.rmtree(osp.join(folder, self.name))
+
+    def process(self) -> None:
+        r"""Handle the data for the dataset.
+
+        This method loads the 3D2M data, applies any pre-processing
+        transformations if specified, and saves the processed data to the
+        appropriate location.
+        """
+
+        # Step 1: Load raw data files
+        data_list = read_3d2m_meshes(self.raw_dir)
+
+
+        # Step 2: Collate the data objects
+        self.data, self.slices = self.collate(data_list)
+        self._data_list = None  # Reset cache.
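+        # collate() above packed the list of Data objects into one concatenated
+        # Data object plus slice indices (the InMemoryDataset convention).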
+
+        # Step 3: Save the processed data
+        fs.torch_save(
+            (self._data.to_dict(), self.slices, {}, self._data.__class__),
+            self.processed_paths[0],
+        )
diff --git a/topobench/data/loaders/__init__.py b/topobench/data/loaders/__init__.py
index e1b726bc6..bf76070fa 100755
--- a/topobench/data/loaders/__init__.py
+++ b/topobench/data/loaders/__init__.py
@@ -1,6 +1,8 @@
 """Init file for load module."""
 
 from .base import AbstractLoader
+from .cell import *
+from .cell import __all__ as cell_all
 from .graph import *
 from .graph import __all__ as graph_all
 from .hypergraph import *
@@ -12,6 +14,7 @@
 
 __all__ = [
     "AbstractLoader",
+    *cell_all,
     *graph_all,
     *hypergraph_all,
     *simplicial_all,
diff --git a/topobench/data/loaders/cell/__init__.py b/topobench/data/loaders/cell/__init__.py
new file mode 100644
index 000000000..a5e5b949a
--- /dev/null
+++ b/topobench/data/loaders/cell/__init__.py
@@ -0,0 +1,99 @@
+"""Init file for cell complex load module with automated loader discovery."""
+
+import inspect
+from importlib import util
+from pathlib import Path
+from typing import Any, ClassVar
+
+
+class CellLoaderManager:
+    """Manages automatic discovery and registration of cell complex loader classes."""
+
+    # Base class that all cell complex loaders are expected to inherit from.
+    # Kept as `object` until a dedicated cell loader base class is introduced.
+    BASE_LOADER_CLASS: ClassVar[type] = object
+
+    @staticmethod
+    def is_loader_class(obj: Any) -> bool:
+        """Check if an object is a valid cell complex loader class.
+
+        Parameters
+        ----------
+        obj : Any
+            The object to check if it's a valid cell complex loader class.
+
+        Returns
+        -------
+        bool
+            True if the object is a valid cell complex loader class (non-private class
+            with 'DatasetLoader' in name), False otherwise.
+        """
+        return (
+            inspect.isclass(obj)
+            and not obj.__name__.startswith("_")
+            and "DatasetLoader" in obj.__name__
+        )
+
+    @classmethod
+    def discover_loaders(cls, package_path: str) -> dict[str, type[Any]]:
+        """Dynamically discover all cell complex loader classes in the package.
+
+        Parameters
+        ----------
+        package_path : str
+            Path to the package's __init__.py file.
+
+        Returns
+        -------
+        dict[str, type[Any]]
+            Dictionary mapping loader class names to their corresponding class objects.
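+            Only non-private classes whose names contain "DatasetLoader"
+            are collected (see is_loader_class).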
+ """ + loaders = {} + + # Get the directory containing the loader modules + package_dir = Path(package_path).parent + + # Iterate through all .py files in the directory + for file_path in package_dir.glob("*.py"): + if file_path.stem == "__init__": + continue + + # Import the module + module_name = f"{Path(package_path).stem}.{file_path.stem}" + spec = util.spec_from_file_location(module_name, file_path) + if spec and spec.loader: + module = util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all cell complex loader classes in the module + new_loaders = { + name: obj + for name, obj in inspect.getmembers(module) + if ( + cls.is_loader_class(obj) + and obj.__module__ == module.__name__ + ) + } + loaders.update(new_loaders) + return loaders + + +# Create the loader manager +manager = CellLoaderManager() + +# Automatically discover and populate loaders +CELL_LOADERS = manager.discover_loaders(__file__) + +CELL_LOADERS_list = list(CELL_LOADERS.keys()) + +# Automatically generate __all__ +__all__ = [ + # Loader collections + "CELL_LOADERS", + "CELL_LOADERS_list", + # Individual loader classes + *CELL_LOADERS.keys(), +] + +# For backwards compatibility, create individual imports +locals().update(**CELL_LOADERS) diff --git a/topobench/data/loaders/cell/d_3d2m_loader.py b/topobench/data/loaders/cell/d_3d2m_loader.py new file mode 100644 index 000000000..ee63c6419 --- /dev/null +++ b/topobench/data/loaders/cell/d_3d2m_loader.py @@ -0,0 +1,43 @@ +"""Loader for 3D2M Cell dataset.""" + + +from omegaconf import DictConfig +from torch_geometric.data import Dataset + +from topobench.data.datasets.d_3d2m_dataset import ( + D3D2MDataset, +) +from topobench.data.loaders.base import AbstractLoader + + +class D3D2MDatasetLoader(AbstractLoader): + """Load 3D2M Cell dataset. + Parameters + ---------- + parameters : DictConfig + Configuration parameters containing: + - data_dir: Root directory for data + - data_name: Name of the dataset + """ + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self) -> Dataset: + """Load 3D2M Cell dataset. + Returns + ------- + Dataset + The loaded 3D2M Cell dataset. + Raises + ------ + RuntimeError + If dataset loading fails. 
+        """
+
+        dataset = D3D2MDataset(
+            root=str(self.root_data_dir),
+            name=self.parameters.data_name,
+            parameters=self.parameters,
+        )
+        return dataset
diff --git a/topobench/data/utils/__init__.py b/topobench/data/utils/__init__.py
index 8793f773e..54f8b63d2 100644
--- a/topobench/data/utils/__init__.py
+++ b/topobench/data/utils/__init__.py
@@ -12,6 +12,8 @@
     load_manual_graph,  # noqa: F401
     load_simplicial_dataset,  # noqa: F401
     make_hash,  # noqa: F401
+    normal,  # noqa: F401
+    reindex,  # noqa: F401
     select_neighborhoods_of_interest,  # noqa: F401
 )
 
@@ -27,6 +29,8 @@
     "ensure_serializable",
     "select_neighborhoods_of_interest",
     "data2simplicial",
+    "normal",
+    "reindex",
     # add function name here
 ]
 
@@ -49,9 +53,10 @@
     download_file_from_link,  # noqa: F401
     load_hypergraph_content_dataset,  # noqa: F401
     load_hypergraph_pickle_dataset,  # noqa: F401
+    read_3d2m_meshes,  # noqa: F401
+    # import function here, add noqa: F401 for PR
     read_ndim_manifolds,  # noqa: F401
     read_us_county_demos,  # noqa: F401
-    # import function here, add noqa: F401 for PR
 )
 
 io_helper_functions = [
@@ -59,6 +64,7 @@
     "load_hypergraph_content_dataset",
     "read_us_county_demos",
     "download_file_from_drive",
+    "read_3d2m_meshes",
     # add function name here
 ]
diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py
index 372db85e6..9389318ff 100644
--- a/topobench/data/utils/io_utils.py
+++ b/topobench/data/utils/io_utils.py
@@ -1,20 +1,27 @@
 """Data IO utilities."""
 
 import json
+import os
 import os.path as osp
 import pickle
+import shutil
 from urllib.parse import parse_qs, urlparse
 
+import gdown
 import numpy as np
 import pandas as pd
 import requests
 import torch
 import torch_geometric
-from toponetx.classes import SimplicialComplex
+from toponetx.classes import CellComplex, SimplicialComplex
 from torch_geometric.data import Data
 from torch_sparse import coalesce
 
-from topobench.data.utils import get_complex_connectivity
+from topobench.data.utils import (
+    get_complex_connectivity,
+    normal,
+    reindex,
+)
 
 
 def get_file_id_from_url(url):
@@ -70,18 +77,23 @@ def download_file_from_drive(
     ------
     None
     """
-    file_id = get_file_id_from_url(file_link)
+    if "drive_link" in file_link:
+        gdown.download(file_link, f"{dataset_name}.{file_format}", fuzzy=True)
+        # Move the downloaded archive from the working directory to path_to_save
+        shutil.move(osp.join(os.curdir, f"{dataset_name}.{file_format}"), path_to_save)
+    else:
+        file_id = get_file_id_from_url(file_link)
 
-    download_link = f"https://drive.google.com/uc?id={file_id}"
-    response = requests.get(download_link)
+        download_link = f"https://drive.google.com/uc?id={file_id}"
+        response = requests.get(download_link)
 
-    output_path = f"{path_to_save}/{dataset_name}.{file_format}"
-    if response.status_code == 200:
-        with open(output_path, "wb") as f:
-            f.write(response.content)
-        print("Download complete.")
-    else:
-        print("Failed to download the file.")
+        output_path = f"{path_to_save}/{dataset_name}.{file_format}"
+        if response.status_code == 200:
+            with open(output_path, "wb") as f:
+                f.write(response.content)
+            print("Download complete.")
+        else:
+            print("Failed to download the file.")
 
 
 def download_file_from_link(
@@ -230,6 +242,136 @@ def read_ndim_manifolds(
     return data_list
 
 
+def single_mesh(directory):
+    """Load a single mesh from the 3D2M dataset.
+
+    Parameters
+    ----------
+    directory : str
+        Path to the directory holding one mesh.
+
+    Returns
+    -------
+    torch_geometric.data.Data
+        Data object for one mesh of the 3D2M dataset.
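+        Returns None if the mesh files cannot be parsed into feature tensors.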
+    """
+    files = os.listdir(directory)
+    obj = [f for f in files if f.endswith(".obj")]
+    npy = [f for f in files if f.endswith(".npy")]
+
+    cell_0 = []
+    faces_ind = []
+    # Load the file with vertices and faces
+    with open(osp.join(directory, obj[0]), encoding="utf-8") as infile:
+        lines = infile.read().splitlines()
+    for line_str in lines:
+        line = line_str.split(" ")
+        if line[0] == "v":
+            cell_0.append(torch.tensor([float(c) for c in line[1:]]))
+        if line[0] == "f":
+            try:
+                faces_ind.append(
+                    [reindex(line[1].split("/")[0]),
+                     reindex(line[2].split("/")[0]),
+                     reindex(line[3].split("/")[0]),
+                     reindex(line[4].split("/")[0])]
+                )
+            except (IndexError, ValueError):  # triangular face: no fourth vertex
+                faces_ind.append(
+                    [reindex(line[1].split("/")[0]),
+                     reindex(line[2].split("/")[0]),
+                     reindex(line[3].split("/")[0])]
+                )
+
+    # build the features x_0
+    try:
+        cell_0_tensor = torch.stack(cell_0).to(torch.float32)
+    except RuntimeError:  # no vertices could be parsed
+        return None
+    cell_0_tensor /= torch.max(cell_0_tensor)
+
+    cell_1_dict = {}
+    for fi in faces_ind:
+        for i in range(len(fi)):
+            if i < len(fi) - 1:
+                key1 = f"{fi[i]}-{fi[i+1]}"
+                key2 = f"{fi[i+1]}-{fi[i]}"
+            else:
+                key1 = f"{fi[i]}-{fi[0]}"
+                key2 = f"{fi[0]}-{fi[i]}"
+            if (key1 in cell_1_dict) or (key2 in cell_1_dict):
+                try:
+                    cell_1_dict[key1] += 1
+                except KeyError:  # the edge was stored with the opposite orientation
+                    cell_1_dict[key2] += 1
+            else:
+                cell_1_dict[key1] = 1
+
+    cell_1 = []
+    for edge in cell_1_dict:
+        edge_list = [int(e) for e in edge.split("-")]
+        cell_1.append(abs(cell_0[edge_list[1]] - cell_0[edge_list[0]]))
+
+    # build the features x_1
+    try:
+        cell_1_tensor = torch.stack(cell_1).to(torch.float32)
+    except RuntimeError:  # no edges could be built
+        return None
+    cell_1_tensor /= torch.max(cell_1_tensor)
+
+    cell_2 = [normal(cell_0[face[0]], cell_0[face[1]], cell_0[face[2]]) for face in faces_ind]
+
+    # build the features x_2
+    try:
+        cell_2_tensor = torch.stack(cell_2).to(torch.float32)
+    except RuntimeError:  # no faces could be built
+        return None
+
+    cx = CellComplex()
+
+    # Insert all cells
+    for face in faces_ind:
+        cx.add_cell(face, 2)
+    for edge in cell_1_dict:
+        cx.add_cell([int(v) for v in edge.split("-")], 1)
+    for n in range(len(cell_0)):
+        cx.add_node(n)
+
+    # Construct the connectivity matrices
+    inc_dict = get_complex_connectivity(cx, 2, signed=False)
+
+    y = torch.tensor([int("Female" in obj[0])])
+
+    y1 = torch.tensor(np.load(osp.join(directory, npy[0]))).to(torch.float32)
+
+    data = Data(x_0=cell_0_tensor, x_1=cell_1_tensor, x_2=cell_2_tensor, y=y, y1=y1, **inc_dict)
+
+    return data
+
+
+def read_3d2m_meshes(folder):
+    """Load 3D2M dataset.
+
+    Parameters
+    ----------
+    folder : str
+        Path to the dataset.
+
+    Returns
+    -------
+    List
+        List with data objects of the complexes of the 3D2M dataset.
+    """
+    directories_list = os.listdir(folder)
+    data_list = []
+    for directory in directories_list:
+        data = single_mesh(osp.join(folder, directory))
+        if isinstance(data, Data):
+            data_list.append(data)
+
+    return data_list
+
+
 def read_us_county_demos(path, year=2012, y_col="Election"):
     """Load US County Demos dataset.
diff --git a/topobench/data/utils/utils.py b/topobench/data/utils/utils.py
index 2bdd12d84..2942e9d29 100755
--- a/topobench/data/utils/utils.py
+++ b/topobench/data/utils/utils.py
@@ -860,3 +860,47 @@ def find_tetrahedrons(incidence_1, incidence_2, incidence_3):
         for i in unique_tetrahedrons
     ]
     return tetrahedron_list
+
+
+def normal(v1, v2, v3):
+    """
+    Compute the unit normal vector of a face from three of its vertices.
+
+    Parameters
+    ----------
+    v1 : torch.Tensor
+        Coordinates of a vertex.
+    v2 : torch.Tensor
+        Coordinates of a vertex.
+    v3 : torch.Tensor
+        Coordinates of a vertex.
+
+    Returns
+    -------
+    torch.Tensor
+        The unit normal of the face.
+    """
+    N = torch.linalg.cross(v2 - v1, v3 - v1)
+    return N / torch.norm(N)
+
+
+def reindex(x):
+    """
+    Convert a one-based index to a zero-based index.
+
+    Parameters
+    ----------
+    x : str or int
+        One-based index.
+
+    Returns
+    -------
+    int
+        Zero-based index.
+    """
+    if isinstance(x, str):
+        return int(x) - 1
+    elif isinstance(x, int):
+        return x - 1
+    else:
+        raise TypeError(f"Expected str or int, got {type(x).__name__}")

From 123012bed89e209f41b13a3e888c885b91adf34 Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 08:52:44 +0000
Subject: [PATCH 2/8] add dataset in README

---
 README.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/README.md b/README.md
index 7e01eb151..f9551dcb7 100755
--- a/README.md
+++ b/README.md
@@ -386,6 +386,13 @@ Specially useful in pre-processing steps, these are the general data manipulatio
 | --- | --- | --- | --- |
 | Mantra | Classification, Multi-label Classification | Predict topological attributes of manifold triangulations | [Source](https://github.com/aidos-lab/MANTRA) (This project includes third-party datasets. See third_party_licenses.txt for licensing information.) |
 
+
+### Cellular Complexes
+| Dataset | Task | Description | Reference |
+| --- | --- | --- | --- |
+| 3D2M | Classification | Complex-level classification of 3D meshes. | [Source](https://arxiv.org/pdf/2410.07415) |
+
+
 ### Hypergraph
 | Dataset | Task | Description | Reference |
 | --- | --- | --- | --- |

From 639f99536a95104f6175aadd2e1e6a7f2ef3decf Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 11:15:16 +0000
Subject: [PATCH 3/8] change pipeline test

---
 configs/dataset/cell/3D2M.yaml            | 1 +
 test/pipeline/test_pipeline.py            | 4 ++++
 topobench/data/datasets/d_3d2m_dataset.py | 3 +--
 topobench/data/utils/io_utils.py          | 8 ++++++--
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/configs/dataset/cell/3D2M.yaml b/configs/dataset/cell/3D2M.yaml
index 0d9a4c515..223d35cf5 100644
--- a/configs/dataset/cell/3D2M.yaml
+++ b/configs/dataset/cell/3D2M.yaml
@@ -6,6 +6,7 @@ loader:
     data_type: topological
     data_name: 3D2M
     data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
+    num_slices: null
 
 # Dataset parameters
 parameters:
diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py
index eeaadd4ec..83b3ca310 100644
--- a/test/pipeline/test_pipeline.py
+++ b/test/pipeline/test_pipeline.py
@@ -24,6 +24,10 @@ def test_pipeline(self):
             overrides=[
                 f"model={MODEL}",
                 f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION
+                "dataset.loader.parameters.num_slices=24",
+                "dataset.split_params.split_type=k-fold",
+                "dataset.split_params.k=2",
+                "dataset.dataloader_params.batch_size=1",
                 "trainer.max_epochs=2",
                 "trainer.min_epochs=1",
                 "trainer.check_val_every_n_epoch=1",
diff --git a/topobench/data/datasets/d_3d2m_dataset.py b/topobench/data/datasets/d_3d2m_dataset.py
index 7bb9a6058..e3ced1eb6 100644
--- a/topobench/data/datasets/d_3d2m_dataset.py
+++ b/topobench/data/datasets/d_3d2m_dataset.py
@@ -155,8 +155,7 @@ def process(self) -> None:
         """
 
         # Step 1: Load raw data files
-        data_list = read_3d2m_meshes(self.raw_dir)
-
+        data_list = read_3d2m_meshes(self.raw_dir, self.parameters.num_slices)
 
         # Step 2: Collate the data objects
         self.data, self.slices = self.collate(data_list)
diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py
index 9389318ff..d0fdd0ac0 100644
--- a/topobench/data/utils/io_utils.py
+++ b/topobench/data/utils/io_utils.py
@@ -349,20 +349,24 @@ def single_mesh(directory):
     return data
 
 
-def read_3d2m_meshes(folder):
+def read_3d2m_meshes(folder, num_slices):
     """Load 3D2M dataset.
 
     Parameters
     ----------
     folder : str
         Path to the dataset.
+    num_slices : int
+        Number of data objects to process and add to the final list.
 
     Returns
     -------
     List
         List with data objects of the complexes of the 3D2M dataset.
     """
-    directories_list = os.listdir(folder)
+    directories = os.listdir(folder)
+    directories_list = directories[:num_slices] if num_slices else directories
+
     data_list = []
     for directory in directories_list:
         data = single_mesh(osp.join(folder, directory))

From 182a03113bbe219220113ee25b71ab456e081e6b Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 14:31:39 +0000
Subject: [PATCH 4/8] fix default transforms config

---
 configs/dataset/cell/3D2M.yaml                | 2 +-
 configs/transforms/dataset_defaults/3D2M.yaml | 3 ++-
 test/pipeline/test_pipeline.py                | 2 +-
 topobench/data/datasets/d_3d2m_dataset.py     | 2 +-
 topobench/data/utils/io_utils.py              | 6 +++---
 5 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/configs/dataset/cell/3D2M.yaml b/configs/dataset/cell/3D2M.yaml
index 223d35cf5..98f5e090c 100644
--- a/configs/dataset/cell/3D2M.yaml
+++ b/configs/dataset/cell/3D2M.yaml
@@ -6,7 +6,7 @@ loader:
     data_type: topological
     data_name: 3D2M
     data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
-    num_slices: null
+    num_data: null
 
 # Dataset parameters
 parameters:
diff --git a/configs/transforms/dataset_defaults/3D2M.yaml b/configs/transforms/dataset_defaults/3D2M.yaml
index ce39aa53d..8db045b1d 100644
--- a/configs/transforms/dataset_defaults/3D2M.yaml
+++ b/configs/transforms/dataset_defaults/3D2M.yaml
@@ -1,3 +1,4 @@
 # 3D2M dataset needs identity transform to avoid adding random float feature to feature matrix
 defaults:
-  - data_manipulations: identity
\ No newline at end of file
+  - data_manipulations: identity
+  - liftings@_here_: ${get_required_lifting:cell,${model}}
diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py
index 83b3ca310..69c82ba12 100644
--- a/test/pipeline/test_pipeline.py
+++ b/test/pipeline/test_pipeline.py
@@ -24,7 +24,7 @@ def test_pipeline(self):
             overrides=[
                 f"model={MODEL}",
                 f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION
-                "dataset.loader.parameters.num_slices=24",
+                "dataset.loader.parameters.num_data=24",
                 "dataset.split_params.split_type=k-fold",
                 "dataset.split_params.k=2",
                 "dataset.dataloader_params.batch_size=1",
diff --git a/topobench/data/datasets/d_3d2m_dataset.py b/topobench/data/datasets/d_3d2m_dataset.py
index e3ced1eb6..f9b32e027 100644
--- a/topobench/data/datasets/d_3d2m_dataset.py
+++ b/topobench/data/datasets/d_3d2m_dataset.py
@@ -155,7 +155,7 @@ def process(self) -> None:
         """
 
         # Step 1: Load raw data files
-        data_list = read_3d2m_meshes(self.raw_dir, self.parameters.num_slices)
+        data_list = read_3d2m_meshes(self.raw_dir, self.parameters.num_data)
 
         # Step 2: Collate the data objects
         self.data, self.slices = self.collate(data_list)
diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py
index d0fdd0ac0..8c5fc2980 100644
--- a/topobench/data/utils/io_utils.py
+++ b/topobench/data/utils/io_utils.py
@@ -349,14 +349,14 @@ def single_mesh(directory):
     return data
 
 
-def read_3d2m_meshes(folder, num_slices):
+def read_3d2m_meshes(folder, num_data):
     """Load 3D2M dataset.
     Parameters
     ----------
     folder : str
         Path to the dataset.
-    num_slices : int
+    num_data : int
         Number of data objects to process and add to the final list.
 
     Returns
     -------
     List
         List with data objects of the complexes of the 3D2M dataset.
     """
     directories = os.listdir(folder)
-    directories_list = directories[:num_slices] if num_slices else directories
+    directories_list = directories[:num_data] if num_data else directories

From b54ce60c0df8bd838c51649683ac33674f955240 Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 16:00:30 +0000
Subject: [PATCH 5/8] add default config file

---
 configs/transforms/liftings/cell2graph.yaml | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 configs/transforms/liftings/cell2graph.yaml

diff --git a/configs/transforms/liftings/cell2graph.yaml b/configs/transforms/liftings/cell2graph.yaml
new file mode 100644
index 000000000..92ec7dff6
--- /dev/null
+++ b/configs/transforms/liftings/cell2graph.yaml
@@ -0,0 +1,2 @@
+defaults:
+  - /transforms/liftings: null
\ No newline at end of file

From 85d0fa98504ea152d807cdd76d416b44ee2e7eab Mon Sep 17 00:00:00 2001
From: gyaguilar
Date: Tue, 25 Nov 2025 17:37:30 +0000
Subject: [PATCH 6/8] Change config file name

---
 .../liftings/{cell2graph.yaml => cell2graph_default.yaml} | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename configs/transforms/liftings/{cell2graph.yaml => cell2graph_default.yaml} (100%)

diff --git a/configs/transforms/liftings/cell2graph.yaml b/configs/transforms/liftings/cell2graph_default.yaml
similarity index 100%
rename from configs/transforms/liftings/cell2graph.yaml
rename to configs/transforms/liftings/cell2graph_default.yaml

From 3292b3feca869f2430c4e6d8742af0f8dfdc95bf Mon Sep 17 00:00:00 2001
From: ixime
Date: Tue, 25 Nov 2025 18:30:14 +0000
Subject: [PATCH 7/8] add edge_index

---
 topobench/data/utils/io_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py
index 8c5fc2980..ab58648ee 100644
--- a/topobench/data/utils/io_utils.py
+++ b/topobench/data/utils/io_utils.py
@@ -340,11 +340,14 @@ def single_mesh(directory):
     # Construct the connectivity matrices
     inc_dict = get_complex_connectivity(cx, 2, signed=False)
 
+    edge_index = torch.tensor(sorted(cx.edges)).T.long()
+
     y = torch.tensor([int("Female" in obj[0])])
 
     y1 = torch.tensor(np.load(osp.join(directory, npy[0]))).to(torch.float32)
 
-    data = Data(x_0=cell_0_tensor, x_1=cell_1_tensor, x_2=cell_2_tensor, y=y, y1=y1, **inc_dict)
+    data = Data(x_0=cell_0_tensor, x_1=cell_1_tensor, x_2=cell_2_tensor,
+                x=cell_0_tensor, edge_index=edge_index, y=y, y1=y1, **inc_dict)
 
     return data

From d2f13913ea9ad4bf721e83d77ddea94d5d329f67 Mon Sep 17 00:00:00 2001
From: gyaguilar
Date: Tue, 25 Nov 2025 18:59:18 +0000
Subject: [PATCH 8/8] Fix test test_datasetloaders

---
 test/data/load/test_datasetloaders.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py
index cb21fd421..f26b33ce5 100644
--- a/test/data/load/test_datasetloaders.py
+++ b/test/data/load/test_datasetloaders.py
@@ -78,11 +78,15 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]:
             job_name="run"
         ):
             print('Current config file: ', config_file)
-            parameters = hydra.compose(
-                config_name="run.yaml",
-                overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"],
-                return_hydra_config=True,
-            )
+            overrides = [f"dataset={data_domain}/{config_file}", "model=graph/gat"]
+            if data_domain == "cell":
+                # Cap the dataset size so the 3D2M download/process step stays fast.
+                overrides.append("dataset.loader.parameters.num_data=24")
+            parameters = hydra.compose(
+                config_name="run.yaml",
+                overrides=overrides,
+                return_hydra_config=True,
+            )
         dataset_loader = hydra.utils.instantiate(parameters.dataset.loader)
 
         print(repr(dataset_loader))
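
---

Editorial usage sketch (not part of the patch series): the snippet below shows how the
new loader can be exercised directly. It assumes that AbstractLoader resolves
`root_data_dir` from the `data_dir` parameter, as the Hydra configs above imply; the
data path and the `num_data` cap are illustrative.

    from omegaconf import OmegaConf

    from topobench.data.loaders import D3D2MDatasetLoader

    # Mirror the fields of configs/dataset/cell/3D2M.yaml
    params = OmegaConf.create(
        {
            "data_domain": "cell",
            "data_type": "topological",
            "data_name": "3D2M",
            "data_dir": "./data/cell/topological",  # illustrative path
            "num_data": 24,  # cap the number of meshes, as in the pipeline test
        }
    )

    # Triggers download/extraction on first use, then loads the processed file
    dataset = D3D2MDatasetLoader(params).load_dataset()
    print(dataset[0])  # Data(x_0=..., x_1=..., x_2=..., edge_index=..., y=..., ...)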