37 changes: 37 additions & 0 deletions configs/dataset/graph/DAC.yaml
@@ -0,0 +1,37 @@
# Dataset loader config
loader:
_target_: topobench.data.loaders.DACDatasetLoader
parameters:
data_domain: graph
data_type: 4-325-1
data_name: 4-325-1
data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_type}
split_num: 2

# Dataset parameters
parameters:
num_features: 2
num_classes: 8
num_nodes: 3224
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
task_level: graph


# Splits
split_params:
learning_setting: inductive
# data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
data_split_dir: ${paths.data_dir}/${dataset.loader.parameters.data_name}/processed
data_seed: 0
split_type: fixed # one of "fixed", "k-fold", or "random"
k: 10 # for "k-fold" Cross-Validation
train_prop: 0.5 # for "random" strategy splitting
standardize: False

# Dataloader parameters
dataloader_params:
batch_size: 4495 # Fixed
num_workers: 0
pin_memory: False
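
For orientation (not part of the diff): a minimal sketch of how Hydra could turn the loader block above into a loader instance. The DACDatasetLoader class named by _target_ is referenced by this config but not shown in this excerpt, and data_dir here is a placeholder for the interpolated ${paths.data_dir} path.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Mirror of the `loader` block above; paths are placeholders.
loader_cfg = OmegaConf.create({
    "_target_": "topobench.data.loaders.DACDatasetLoader",
    "parameters": {
        "data_domain": "graph",
        "data_type": "4-325-1",
        "data_name": "4-325-1",
        "data_dir": "datasets/4-325-1",  # stands in for ${paths.data_dir}/...
        "split_num": 2,
    },
})

loader = instantiate(loader_cfg)   # builds the object named by _target_
dataset = loader.load_dataset()    # downloads and processes data on first use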
6 changes: 3 additions & 3 deletions test/pipeline/test_pipeline.py
@@ -1,11 +1,11 @@
"""Test pipeline for a particular dataset and model."""

import hydra
from test._utils.simplified_pipeline import run


DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE

DATASET = "graph/DAC" # ADD YOUR DATASET HERE
MODELS = ["graph/gcn"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE


class TestPipeline:
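For reference (not part of the diff): one way to run just this pipeline test from Python; invoking pytest on the file from the shell works the same way.

import pytest

# Run only the simplified pipeline test for the new DATASET/MODELS pair above.
pytest.main(["test/pipeline/test_pipeline.py", "-q"])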
159 changes: 159 additions & 0 deletions topobench/data/datasets/dac_dataset.py
@@ -0,0 +1,159 @@
"""Dataset class for Dynamic Activity Complex (DAC) dataset."""

import os
import os.path as osp
import shutil
from typing import ClassVar

import torch
from omegaconf import DictConfig
from torch_geometric.data import Data, InMemoryDataset, extract_zip

from topobench.data.utils import download_file_from_link


class DACDataset(InMemoryDataset):
r"""Dataset class for the Dynamic Activity Complexes (DAC) dataset.

Parameters
----------
root : str
Root directory where the dataset will be saved.
name : str
Name of the dataset.
parameters : DictConfig
Configuration parameters for the dataset.

Attributes
----------
URLS (dict): Dictionary containing the URLs for downloading the dataset.
"""

URLS: ClassVar = {
"4-325-1": "https://zenodo.org/records/17700425/files/4_325_1.zip",
"4-325-3": "https://zenodo.org/records/17700425/files/4_325_3.zip",
"4-325-5": "https://zenodo.org/records/17700425/files/4_325_5.zip",
}

def __init__(
self,
root: str,
name: str,
parameters: DictConfig,
):
# Load processed data (created in process())
self.name = name
super().__init__(root)
self.data, self.slices, self.splits = torch.load(
self.processed_paths[0]
)

split_num = parameters.split_num
self.split_idx = self.splits[split_num]

@property
def raw_file_names(self):
"""Return the raw file names for the dataset.

Returns
-------
list[str]
List of raw file names.
"""
return [
"all_edges.pt",
"all_x.pt",
"y.pt",
"split_0.pt",
"split_1.pt",
"split_2.pt",
"split_3.pt",
"split_4.pt",
]

@property
def processed_file_names(self):
"""Return the processed file name for the dataset.

Returns
-------
str
Processed file name.
"""
return ["data.pt"]

@property
def processed_dir(self) -> str:
"""Return the path to the processed directory of the dataset.

Returns
-------
str
Path to the processed directory.
"""
self.processed_root = self.root
return osp.join(self.processed_root, "processed")

def download(self):
r"""Download the dataset from a URL and saves it to the raw directory.

Raises:
FileNotFoundError: If the dataset URL is not found.
"""
# Step 1: Download data from the source
self.url = self.URLS[self.name]
download_file_from_link(
file_link=self.url,
path_to_save=self.raw_dir,
dataset_name=self.name,
file_format="zip",
)

# Step 2: extract zip file
folder = self.raw_dir
filename = f"{self.name}.zip"
path = osp.join(folder, filename)
extract_zip(path, folder)
# Delete zip file
os.unlink(path)

# Step 3: organize files
# Move files from osp.join(folder, name_download) to folder
folder_name = "4_325_" + self.name.split("-")[2]
for file in os.listdir(osp.join(folder, folder_name)):
shutil.move(osp.join(folder, folder_name, file), folder)
# Delete osp.join(folder, self.name) dir
shutil.rmtree(osp.join(folder, folder_name))

def process(self):
r"""Handle the data for the dataset.

This method loads the DAC raw data, creates one object for
each graph, and saves the processed data
to the appropriate location.
"""
# Load raw tensors
relations = torch.load(os.path.join(self.raw_dir, "all_edges.pt"))
all_x = torch.load(os.path.join(self.raw_dir, "all_x.pt"))
y = torch.load(os.path.join(self.raw_dir, "y.pt"))

data_list = []
for i in range(len(all_x)):
# Create PyG Data object
data = Data(
x=all_x[i],
edge_index=relations[i],
y=y[i].unsqueeze(0) if y[i].ndim == 0 else y[i],
)

data_list.append(data)

# Save to processed dir using slicing format
data, slices = self.collate(data_list)

splits = []
for s in range(5):
split = torch.load(os.path.join(self.raw_dir, f"split_{s}.pt"))
splits.append(split)

torch.save((data, slices, splits), self.processed_paths[0])
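
A minimal usage sketch for the class above (not part of the diff). The root path is a placeholder, and parameters is assumed to need only split_num, matching its use in __init__.

from omegaconf import OmegaConf
from topobench.data.datasets import DACDataset

params = OmegaConf.create({"split_num": 2})  # choose one of the five stored splits
dataset = DACDataset(
    root="datasets/4-325-1",  # placeholder root directory
    name="4-325-1",           # must be a key of DACDataset.URLS
    parameters=params,
)

print(len(dataset))       # number of graphs (one Data object per sample)
print(dataset[0])         # Data(x=..., edge_index=..., y=...)
print(dataset.split_idx)  # indices stored for split 2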
98 changes: 98 additions & 0 deletions topobench/data/loaders/combinatorial/__init__.py
@@ -0,0 +1,98 @@
"""Init file for combinatorial dataset load module with automated loader discovery."""

import inspect
from importlib import util
from pathlib import Path
from typing import Any, ClassVar


class CombinatorialLoaderManager:
"""Manages automatic discovery and registration of combinatorial dataset loader classes."""

# Base class that all combinatorial loaders should inherit from (adjust based on your actual base class)
BASE_LOADER_CLASS: ClassVar[type] = object

@staticmethod
def is_loader_class(obj: Any) -> bool:
"""Check if an object is a valid combinatorial dataset loader class.

Parameters
----------
obj : Any
The object to check if it's a valid combinatorial dataset loader class.

Returns
-------
bool
True if the object is a valid combinatorial dataset loader class (non-private class
with 'DatasetLoader' in name), False otherwise.
"""
return (
inspect.isclass(obj)
and not obj.__name__.startswith("_")
and "DatasetLoader" in obj.__name__
)

@classmethod
def discover_loaders(cls, package_path: str) -> dict[str, type[Any]]:
"""Dynamically discover all combinatorial dataset loader classes in the package.

Parameters
----------
package_path : str
Path to the package's __init__.py file.

Returns
-------
dict[str, type[Any]]
Dictionary mapping loader class names to their corresponding class objects.
"""
loaders = {}

# Get the directory containing the loader modules
package_dir = Path(package_path).parent

# Iterate through all .py files in the directory
for file_path in package_dir.glob("*.py"):
if file_path.stem == "__init__":
continue

# Import the module
module_name = f"{Path(package_path).stem}.{file_path.stem}"
spec = util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find all combinatorial dataset loader classes in the module
new_loaders = {
name: obj
for name, obj in inspect.getmembers(module)
if (
cls.is_loader_class(obj)
and obj.__module__ == module.__name__
)
}
loaders.update(new_loaders)
return loaders


# Create the loader manager
manager = CombinatorialLoaderManager()

# Automatically discover and populate loaders
COMBINATORIAL_LOADERS = manager.discover_loaders(__file__)

COMBINATORIAL_LOADERS_list = list(COMBINATORIAL_LOADERS.keys())

# Automatically generate __all__
__all__ = [
# Loader collections
"COMBINATORIAL_LOADERS",
"COMBINATORIAL_LOADERS_list",
# Individual loader classes
*COMBINATORIAL_LOADERS.keys(),
]

# For backwards compatibility, create individual imports
locals().update(**COMBINATORIAL_LOADERS)
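
An illustrative check of what the discovery above produces once dac_dataset_loader.py (next file) sits in this package; the loader name shown is taken from that file.

from topobench.data.loaders import combinatorial

# Every non-private class in this package whose name contains "DatasetLoader"
print(combinatorial.COMBINATORIAL_LOADERS_list)
# expected to include 'DACCombinatorialDatasetLoader'

# Discovered classes are also re-exported at package level via locals().update
LoaderCls = combinatorial.COMBINATORIAL_LOADERS["DACCombinatorialDatasetLoader"]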
70 changes: 70 additions & 0 deletions topobench/data/loaders/combinatorial/dac_dataset_loader.py
@@ -0,0 +1,70 @@
"""Loaders for Mantra dataset as simplicial."""

from omegaconf import DictConfig

from topobench.data.datasets import DACDataset
from topobench.data.loaders.base import AbstractLoader


class DACCombinatorialDatasetLoader(AbstractLoader):
"""Load Mantra dataset with configurable parameters.

Note: for combinatorial datasets it is necessary to include DatasetLoader in the name of the class!

Parameters
----------
parameters : DictConfig
Configuration parameters containing:
- data_dir: Root directory for data
- data_name: Name of the dataset
- other relevant parameters

**kwargs : dict
Additional keyword arguments.
"""

def __init__(self, parameters: DictConfig, **kwargs) -> None:
super().__init__(parameters, **kwargs)

def load_dataset(self, **kwargs) -> DACDataset:
"""Load the DAC Combinatorial dataset.

Parameters
----------
**kwargs : dict
Additional keyword arguments for dataset initialization.

Returns
-------
DACDataset
The loaded DAC dataset with the appropriate `data_dir`.

Raises
------
RuntimeError
If dataset loading fails.
"""

dataset = self._initialize_dataset(**kwargs)
self.data_dir = self.get_data_dir()
return dataset

def _initialize_dataset(self, **kwargs) -> DACDataset:
"""Initialize the Citation Hypergraph dataset.

Parameters
----------
**kwargs : dict
Additional keyword arguments for dataset initialization.

Returns
-------
DACDataset
The initialized dataset instance.
"""
return DACDataset(
root=str(self.root_data_dir),
name=self.parameters.data_name,
parameters=self.parameters,
**kwargs,
)
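
A sketch of direct (non-Hydra) use of the loader above; the data_dir value is a placeholder, the data_domain value is an assumption, and it relies on AbstractLoader exposing the configured directory as root_data_dir, as the call in _initialize_dataset suggests.

from omegaconf import OmegaConf
from topobench.data.loaders.combinatorial.dac_dataset_loader import (
    DACCombinatorialDatasetLoader,
)

params = OmegaConf.create({
    "data_domain": "combinatorial",  # assumed value for this loader
    "data_type": "4-325-1",
    "data_name": "4-325-1",
    "data_dir": "datasets/4-325-1",  # placeholder path
    "split_num": 0,
})

loader = DACCombinatorialDatasetLoader(params)
dataset = loader.load_dataset()  # returns a DACDataset instance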