7 changes: 7 additions & 0 deletions README.md
@@ -386,6 +386,13 @@ Specially useful in pre-processing steps, these are the general data manipulations
| --- | --- | --- | --- |
| Mantra | Classification, Multi-label Classification | Predict topological attributes of manifold triangulations | [Source](https://github.com/aidos-lab/MANTRA) (This project includes third-party datasets. See third_party_licenses.txt for licensing information.) |


### Cellular Complexes
| Dataset | Task | Description | Reference |
| --- | --- | --- | --- |
| 3D2M | Classification | Predict complex-level labels of 3D meshes represented as cellular complexes | [Source](https://arxiv.org/pdf/2410.07415) |


### Hypergraph
| Dataset | Task | Description | Reference |
| --- | --- | --- | --- |
34 changes: 34 additions & 0 deletions configs/dataset/cell/3D2M.yaml
@@ -0,0 +1,34 @@
# Dataset loader config
loader:
_target_: topobench.data.loaders.D3D2MDatasetLoader
parameters:
data_domain: cell
data_type: topological
data_name: 3D2M
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
num_data: null

# Dataset parameters
parameters:
  num_features: [3, 3, 3]
num_classes: 2
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph

# Splits
split_params:
learning_setting: inductive
data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_seed: 0
  split_type: random # either "k-fold" or "random"
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting

# Dataloader parameters
dataloader_params:
batch_size: 8 # Fixed
num_workers: 1
pin_memory: False
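
For context, a minimal sketch of how this config composes and instantiates (mirroring the `hydra.compose` pattern in `test_datasetloaders.py` below; `config_path`, `version_base`, and the override values are illustrative, not fixed by this PR):

```python
import hydra

# Compose the run config with the new 3D2M dataset and instantiate its loader.
with hydra.initialize(version_base="1.3", config_path="../../configs", job_name="run"):
    parameters = hydra.compose(
        config_name="run.yaml",
        overrides=[
            "dataset=cell/3D2M",
            "model=cell/cwn2",
            "dataset.loader.parameters.num_data=24",  # load only a slice for quick checks
        ],
        return_hydra_config=True,
    )
    # Instantiate the class named in `loader._target_` above.
    dataset_loader = hydra.utils.instantiate(parameters.dataset.loader)
```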
38 changes: 38 additions & 0 deletions configs/model/cell/cwn2.yaml
@@ -0,0 +1,38 @@
_target_: topobench.model.TBModel

model_name: cwn
model_domain: cell

feature_encoder:
_target_: topobench.nn.encoders.${model.feature_encoder.encoder_name}
encoder_name: AllCellFeatureEncoder
in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}}
out_channels: 3
proj_dropout: 0.0

backbone:
_target_: topomodelx.nn.cell.cwn.CWN
in_channels_0: ${model.feature_encoder.out_channels}
in_channels_1: ${model.feature_encoder.out_channels}
in_channels_2: ${model.feature_encoder.out_channels}
hid_channels: ${model.feature_encoder.out_channels}
n_layers: 4

backbone_wrapper:
_target_: topobench.nn.wrappers.CWNWrapper
_partial_: true
wrapper_name: CWNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}}

readout:
_target_: topobench.nn.readouts.${model.readout.readout_name}
  readout_name: PropagateSignalDown # Options: PropagateSignalDown, NoReadOut (use NoReadOut if no readout is needed)
num_cell_dimensions: ${infer_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider
hidden_dim: ${model.feature_encoder.out_channels}
out_channels: ${dataset.parameters.num_classes}
task_level: ${dataset.parameters.task_level}
pooling_type: sum

# compile model for faster training with pytorch 2.0
compile: false
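
As a sanity check on the interpolations: with `feature_encoder.out_channels: 3`, Hydra instantiates the backbone with exactly the keyword arguments below. This is a sketch of the equivalent direct construction, assuming only what the config itself passes:

```python
from topomodelx.nn.cell.cwn import CWN

# The ${model.feature_encoder.out_channels} interpolations all resolve to 3 here.
backbone = CWN(
    in_channels_0=3,
    in_channels_1=3,
    in_channels_2=3,
    hid_channels=3,
    n_layers=4,
)
```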
4 changes: 4 additions & 0 deletions configs/transforms/dataset_defaults/3D2M.yaml
@@ -0,0 +1,4 @@
# The 3D2M dataset needs the identity transform to avoid adding a random float feature to the feature matrix
defaults:
- data_manipulations: identity
- liftings@_here_: ${get_required_lifting:cell,${model}}
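
A no-op transform in torch_geometric style looks like the sketch below; topobench's actual `identity` manipulation lives in its data_manipulations registry and may differ, so this only illustrates why the feature matrix stays untouched:

```python
from torch_geometric.transforms import BaseTransform


class IdentityTransform(BaseTransform):
    """Return the data object unchanged (illustrative sketch, not topobench's implementation)."""

    def forward(self, data):
        return data
```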
2 changes: 2 additions & 0 deletions configs/transforms/liftings/cell2graph_default.yaml
@@ -0,0 +1,2 @@
defaults:
- /transforms/liftings: null
1 change: 1 addition & 0 deletions pyproject.toml
@@ -54,6 +54,7 @@ dependencies=[
"topomodelx @ git+https://github.com/pyt-team/TopoModelX.git",
"toponetx @ git+https://github.com/pyt-team/TopoNetX.git",
"lightning==2.4.0",
"gdown",
]

[project.optional-dependencies]
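
The new `gdown` dependency presumably backs the `download_file_from_drive` helper used by the dataset class below. A hedged, standalone sketch of that pattern — `download_drive_file` is a hypothetical stand-in, not topobench's actual helper:

```python
import os

import gdown


def download_drive_file(file_link: str, path_to_save: str, dataset_name: str, file_format: str) -> str:
    """Illustrative stand-in for topobench.data.utils.download_file_from_drive."""
    os.makedirs(path_to_save, exist_ok=True)
    output = os.path.join(path_to_save, f"{dataset_name}.{file_format}")
    # fuzzy=True lets gdown resolve Drive 'view' share links like the 3D2M URL below.
    gdown.download(url=file_link, output=output, quiet=False, fuzzy=True)
    return output
```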
17 changes: 12 additions & 5 deletions test/data/load/test_datasetloaders.py
@@ -78,11 +78,18 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]:
job_name="run"
):
print('Current config file: ', config_file)
parameters = hydra.compose(
config_name="run.yaml",
overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"],
return_hydra_config=True,
)
            overrides = [f"dataset={data_domain}/{config_file}", "model=graph/gat"]
            if data_domain == "cell":
                # Cell datasets can be large; load only a small slice in tests.
                overrides.append("dataset.loader.parameters.num_data=24")
            parameters = hydra.compose(
                config_name="run.yaml",
                overrides=overrides,
                return_hydra_config=True,
            )
dataset_loader = hydra.utils.instantiate(parameters.dataset.loader)
print(repr(dataset_loader))

8 changes: 6 additions & 2 deletions test/pipeline/test_pipeline.py
@@ -4,8 +4,8 @@
from test._utils.simplified_pipeline import run


DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
DATASET = "cell/3D2M" # ADD YOUR DATASET HERE
MODELS = ["cell/cwn2"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE


class TestPipeline:
@@ -24,6 +24,10 @@ def test_pipeline(self):
overrides=[
f"model={MODEL}",
f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION
"dataset.loader.parameters.num_data=24",
"dataset.split_params.split_type=k-fold",
"dataset.split_params.k=2",
"dataset.dataloader_params.batch_size=1",
"trainer.max_epochs=2",
"trainer.min_epochs=1",
"trainer.check_val_every_n_epoch=1",
169 changes: 169 additions & 0 deletions topobench/data/datasets/d_3d2m_dataset.py
@@ -0,0 +1,169 @@
"""Dataset class for 3D2M dataset."""

import os
import os.path as osp
import shutil
from contextlib import suppress
from pathlib import Path
from typing import ClassVar

from omegaconf import DictConfig
from torch_geometric.data import Data, InMemoryDataset, extract_zip
from torch_geometric.io import fs

from topobench.data.utils import (
download_file_from_drive,
read_3d2m_meshes,
)


class D3D2MDataset(InMemoryDataset):
r"""Dataset class for 3D2M dataset.
Parameters
----------
root : str
Root directory where the dataset will be saved.
name : str
Name of the dataset.
parameters : DictConfig
Configuration parameters for the dataset.
Attributes
----------
URLS (dict): Dictionary containing the URLs for downloading the dataset.
FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset.
"""

URLS: ClassVar = {
"3D2M": "https://drive.google.com/file/d/1jxVSmjDQmojh_5LHPLb9RmSIenWPiYnj/view?usp=drive_link",
}

FILE_FORMAT: ClassVar = {
"3D2M": "zip",
}

RAW_FILE_NAMES: ClassVar = {}

def __init__(
self,
root: str,
name: str,
parameters: DictConfig,
) -> None:
self.name = name
self.parameters = parameters
super().__init__(
root,
)

out = fs.torch_load(self.processed_paths[0])
assert len(out) == 3 or len(out) == 4

if len(out) == 3: # Backward compatibility.
data, self.slices, self.sizes = out
data_cls = Data
else:
data, self.slices, self.sizes, data_cls = out

if not isinstance(data, dict): # Backward compatibility.
self.data = data
else:
self.data = data_cls.from_dict(data)

assert isinstance(self._data, Data)

def __repr__(self) -> str:
return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.force_reload={self.force_reload})"

@property
def raw_dir(self) -> str:
"""Return the path to the raw directory of the dataset.
Returns
-------
str
Path to the raw directory.
"""
return osp.join(self.root, self.name, "raw")

@property
def processed_dir(self) -> str:
"""Return the path to the processed directory of the dataset.
Returns
-------
str
Path to the processed directory.
"""

return osp.join(self.root, self.name, "processed")

@property
def raw_file_names(self) -> list[str]:
"""Return the raw file names for the dataset.
Returns
-------
list[str]
List of raw file names.
"""
return ["*.obj", "*.npy"]

@property
def processed_file_names(self) -> str:
"""Return the processed file name for the dataset.
Returns
-------
str
Processed file name.
"""
return "data.pt"


def download(self) -> None:
r"""Download the dataset from a URL and saves it to the raw directory.
Raises:
FileNotFoundError: If the dataset URL is not found.
"""
# Download data from the source
self.url = self.URLS[self.name]
self.file_format = self.FILE_FORMAT[self.name]
download_file_from_drive(
file_link=self.url,
path_to_save=self.raw_dir,
dataset_name=self.name,
file_format=self.file_format,
)

# Extract zip file
folder = self.raw_dir
filename = f"{self.name}.{self.file_format}"
path = osp.join(folder, filename)
extract_zip(path, folder)
# Delete zip file
os.unlink(path)

        # Move extracted subfolders from osp.join(folder, self.name) up into folder
data_folder_path = Path(osp.join(folder, self.name))
for subfolder in data_folder_path.iterdir():
with suppress(Exception):
subfolder.rename(Path(folder) / subfolder.name)
shutil.rmtree(osp.join(folder, self.name))

def process(self) -> None:
r"""Handle the data for the dataset.
This method loads the 3D2M data, applies any pre-
processing transformations if specified, and saves the processed data
to the appropriate location.
"""

# Step 1: Load raw data files
data_list = read_3d2m_meshes(self.raw_dir, self.parameters.num_data)

# Step 2: collate the graphs
self.data, self.slices = self.collate(data_list)
self._data_list = None # Reset cache.

# Step 3: save processed data
fs.torch_save(
(self._data.to_dict(), self.slices, {}, self._data.__class__),
self.processed_paths[0],
)
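
A hedged usage sketch of the class: the parameter keys mirror configs/dataset/cell/3D2M.yaml, while the `root` path is illustrative:

```python
from omegaconf import OmegaConf

from topobench.data.datasets.d_3d2m_dataset import D3D2MDataset

# num_data=None loads every mesh; set an int to take a slice, as the tests do.
params = OmegaConf.create({"num_data": None})
dataset = D3D2MDataset(root="datasets/cell/topological", name="3D2M", parameters=params)
print(len(dataset), dataset[0])
```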

3 changes: 3 additions & 0 deletions topobench/data/loaders/__init__.py
@@ -1,6 +1,8 @@
"""Init file for load module."""

from .base import AbstractLoader
from .cell import *
from .cell import __all__ as cell_all
from .graph import *
from .graph import __all__ as graph_all
from .hypergraph import *
@@ -12,6 +14,7 @@

__all__ = [
"AbstractLoader",
*cell_all,
*graph_all,
*hypergraph_all,
*simplicial_all,
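
The re-export above assumes `topobench/data/loaders/cell/__init__.py` defines `__all__`. That file is not shown in this diff, but its expected shape is roughly as follows (the submodule name is an assumption; only the `D3D2MDatasetLoader` class name is fixed by the dataset config):

```python
# topobench/data/loaders/cell/__init__.py (sketch; module name assumed)
from .d3d2m_dataset_loader import D3D2MDatasetLoader

__all__ = [
    "D3D2MDatasetLoader",
]
```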