44 changes: 44 additions & 0 deletions configs/dataset/pointcloud/semantic.yaml
@@ -0,0 +1,44 @@
# Semantic Dataset Configuration

# Dataset loader config
loader:
  _target_: topobench.data.loaders.SemanticDatasetLoader
  parameters:
    n_subsampling: 1000
    data_domain: pointcloud
    data_type: semantic
    data_name: cifar10
    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
    models:
      - aimv2_1b_patch14_224.apple_pt
      # - aimv2_1b_patch14_336.apple_pt
      # - aimv2_1b_patch14_448.apple_pt

# Dataset parameters
parameters:
  data_name: cifar10
  num_features: 2048
  num_nodes: 1000
  num_classes: 10
  loss_type: cross_entropy
  monitor_metric: accuracy
  task: classification
  task_level: node
  # Lifting parameters
  max_dim_if_lifted: 3 # This is the maximum dimension of the simplicial complex in the dataset
  preserve_edge_attr_if_lifted: True

# Splits
split_params:
  learning_setting: transductive
  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
  data_seed: 0
  split_type: random
  k: 10
  train_prop: 0.5

# Dataloader parameters
dataloader_params:
  batch_size: 1
  num_workers: 1
  pin_memory: False
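
The `_target_` key indicates Hydra-style instantiation. Below is a minimal sketch of how this config resolves into a loader, assuming TopoBench instantiates it via `hydra.utils.instantiate` and that `${paths.data_dir}` interpolates to a local directory (both assumptions; neither is shown in this diff):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hand-built stand-in for the loader section of semantic.yaml;
# the data_dir interpolation is replaced by a literal path.
cfg = OmegaConf.create(
    {
        "_target_": "topobench.data.loaders.SemanticDatasetLoader",
        "parameters": {
            "n_subsampling": 1000,
            "data_domain": "pointcloud",
            "data_type": "semantic",
            "data_name": "cifar10",
            "data_dir": "datasets/pointcloud/semantic",
            "models": ["aimv2_1b_patch14_224.apple_pt"],
        },
    }
)

loader = instantiate(cfg)        # SemanticDatasetLoader(parameters=...)
dataset = loader.load_dataset()  # downloads and processes the embeddings
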
1 change: 1 addition & 0 deletions pyproject.toml
@@ -54,6 +54,7 @@ dependencies=[
    "topomodelx @ git+https://github.com/pyt-team/TopoModelX.git",
    "toponetx @ git+https://github.com/pyt-team/TopoNetX.git",
    "lightning==2.4.0",
    "datasets==4.4.1",
]

[project.optional-dependencies]
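
The new `datasets` pin backs the HF download path used by the dataset class added below. A quick sanity-check sketch of that same call pattern (assuming the `spaicom-lab/semantic-cifar10` repo is public and reachable):

from datasets import load_dataset

# Same repo/config naming scheme as SemanticDataset below:
# repo "spaicom-lab/semantic-<data_name>", config = model name.
ds = load_dataset(
    "spaicom-lab/semantic-cifar10",
    "aimv2_1b_patch14_224.apple_pt",
    cache_dir="example/cifar10/raw",
)
print(ds)  # DatasetDict with "train"/"test" splits of label/embedding columns
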
4 changes: 2 additions & 2 deletions test/pipeline/test_pipeline.py
@@ -4,8 +4,8 @@
from test._utils.simplified_pipeline import run


DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
DATASET = "pointcloud/semantic" # ADD YOUR DATASET HERE
MODELS = ["non_relational/mlp"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE


class TestPipeline:
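
With the constants swapped to the new dataset and an MLP model, the smoke test can be exercised directly; a minimal invocation sketch (assuming pytest, which the repo's test suite already relies on):

import pytest

# Run only the pipeline smoke test for the pointcloud/semantic dataset.
pytest.main(["test/pipeline/test_pipeline.py"])
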
178 changes: 178 additions & 0 deletions topobench/data/datasets/semantic_datasets.py
@@ -0,0 +1,178 @@
"""Dataset class for Semantic Representations."""

import os.path as osp

import torch
from datasets import concatenate_datasets, load_dataset
from omegaconf import DictConfig
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
from tqdm.auto import tqdm


class SemanticDataset(InMemoryDataset):
r"""Dataset class for semantic representation of real datasets.

Parameters
----------
root : str
Root directory where the dataset will be saved.
parameters : DictConfig
Configuration parameters for the dataset.
"""

parameters = DictConfig(
{
"data_name": None,
"n_subsampling": None,
"models": None,
}
)

    def __init__(
        self,
        root: str,
        parameters: DictConfig,
    ) -> None:
        # Merge the user parameters into a copy of the defaults, so the
        # class-level defaults are not mutated across instances.
        self.parameters = DictConfig({**self.parameters, **parameters})

        # Unpack parameters
        self.data_name = self.parameters.data_name
        self.models = self.parameters.models
        self.n_subsampling = self.parameters.n_subsampling

        assert self.data_name is not None, (
            'The "data_name" parameter must be set to a dataset name.'
        )
        assert self.models is not None, (
            'The "models" parameter must be set to a list of model names.'
        )

        # HF repository URL
        self.url: str = f"spaicom-lab/semantic-{self.data_name}"

        # Call the super init (triggers download/process when needed)
        super().__init__(root)

        # Load the tensors saved by `process`.
        out = fs.torch_load(self.processed_paths[0])
        data, self.slices, self.sizes, data_cls = out
        self.data = data_cls.from_dict(data)
        assert isinstance(self._data, Data)

    def __repr__(self) -> str:
        return f"{self.data_name}(parameters={self.parameters}, force_reload={self.force_reload})"

    @property
    def raw_dir(self) -> str:
        """Return the path to the raw directory of the dataset.

        Returns
        -------
        str
            Path to the raw directory.
        """
        return osp.join(self.root, self.data_name, "raw")

    @property
    def raw_file_names(self) -> list[str]:
        """Return the raw file names for the dataset.

        Returns
        -------
        list[str]
            List of raw file names.
        """
        return []

    @property
    def processed_dir(self) -> str:
        """Return the path to the processed directory of the dataset.

        Returns
        -------
        str
            Path to the processed directory.
        """
        return osp.join(self.root, self.data_name, "processed")

    @property
    def processed_file_names(self) -> str:
        """Return the processed file name for the dataset.

        Returns
        -------
        str
            Processed file name.
        """
        return "data.pt"

    def download(self) -> None:
        r"""Download the dataset from the HF repository."""
        for model in tqdm(self.models, desc="Loading Models"):
            load_dataset(self.url, model, cache_dir=self.raw_dir)

    def process(self) -> None:
        r"""Process the raw data for the dataset.

        For each model, this loads the semantic representation data,
        optionally subsamples it, wraps it in a PyG ``Data`` object, and
        saves the collated result to the processed directory.
        """
        data_list = []
        for model in tqdm(self.models, desc="Preprocessing Models"):
            dataset = load_dataset(self.url, model, cache_dir=self.raw_dir)

            dataset = concatenate_datasets([dataset["train"], dataset["test"]])
            if self.n_subsampling is not None:
                n_sub = min(self.n_subsampling, len(dataset))
                dataset = dataset.select(range(n_sub))

            label = torch.tensor(dataset["label"])
            embedding = torch.tensor(dataset["embedding"])

            data = Data(x=embedding, y=label)
            data.model = model

            data_list.append(data)

        self.data, self.slices = self.collate(data_list)
        self._data_list = None  # Reset cache.
        self._data.dataset = self.data_name
        fs.torch_save(
            (self._data.to_dict(), self.slices, {}, self._data.__class__),
            self.processed_paths[0],
        )


if __name__ == "__main__":
    # Example usage
    root: str = "example"
    parameters: DictConfig = DictConfig(
        {
            "data_name": "cifar10",
            "models": [
                "aimv2_1b_patch14_224.apple_pt",
                "aimv2_1b_patch14_336.apple_pt",
                "aimv2_1b_patch14_448.apple_pt",
            ],
        }
    )

    # Initialize a Semantic Dataset
    dataset = SemanticDataset(root=root, parameters=parameters)
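
Each entry in `models` becomes one `Data` object in the collated dataset, so a quick inspection sketch (assuming the `__main__` example above ran and the HF repo was reachable):

# One Data object per model; x holds the embeddings, y the class labels.
for i in range(len(dataset)):
    d = dataset[i]
    print(d.model, d.x.shape, d.y.shape)
# Expected form: aimv2_1b_patch14_224.apple_pt torch.Size([N, F]) torch.Size([N]),
# where N and F depend on the HF repo contents and on n_subsampling.
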
76 changes: 76 additions & 0 deletions topobench/data/loaders/pointcloud/semantic_dataset_loader.py
@@ -0,0 +1,76 @@
"""Loaders for Semantic representation dataset."""

from omegaconf import DictConfig

from topobench.data.datasets import SemanticDataset
from topobench.data.loaders.base import AbstractLoader


class SemanticDatasetLoader(AbstractLoader):
"""Load Semantic representations dataset with configurable parameters.

Parameters
----------
parameters : DictConfig
Configuration parameters containing:
- data_dir: Root directory for data.
- data_name: Name of the dataset.
- models: a list of (neural) model names used to build semantic representations.
- other relevant parameters.
"""

def __init__(
self,
parameters: DictConfig,
) -> None:
super().__init__(parameters)
self.parameters = parameters

def load_dataset(self) -> SemanticDataset:
"""Load the Semantic dataset.

Returns
-------
SemanticDataset
The loaded a Semantic dataset with the appropriate `name`.

Raises
------
RuntimeError
If dataset loading fails.
"""

dataset = self._initialize_dataset()
self.data_dir = self.get_data_dir()
return dataset

def _initialize_dataset(self) -> SemanticDataset:
"""Initialize the Semantic dataset.

Returns
-------
SemanticDataset
The initialized dataset instance.
"""
return SemanticDataset(
root=str(self.root_data_dir),
parameters=self.parameters,
)


if __name__ == "__main__":
    # Example usage
    parameters: DictConfig = DictConfig(
        {
            "data_dir": "example",
            "data_name": "cifar10",
            "models": [
                "aimv2_1b_patch14_224.apple_pt",
                "aimv2_1b_patch14_336.apple_pt",
                "aimv2_1b_patch14_448.apple_pt",
            ],
        }
    )

    # Initialize a Semantic dataset loader
    loader = SemanticDatasetLoader(parameters=parameters)
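
A hedged end-to-end sketch of the loader in use; this assumes `AbstractLoader` derives `root_data_dir` and `get_data_dir()` from `data_dir`, a contract defined in `topobench/data/loaders/base.py` outside this diff:

# Continuing from the __main__ block above:
dataset = loader.load_dataset()
print(len(dataset))        # one Data object per entry in "models"
print(dataset[0].x.shape)  # (num_points, embedding_dim) for the first model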