From b92ef0eb098492654625adffe60181ab8bcbdf2a Mon Sep 17 00:00:00 2001 From: jableable Date: Tue, 25 Nov 2025 10:05:35 -0600 Subject: [PATCH 1/2] Add Metamath dataset, loader, and pipeline test --- configs/dataset/graph/metamath.yaml | 30 +++ test/pipeline/test_pipeline.py | 28 +-- topobench/data/datasets/metamath_dataset.py | 198 ++++++++++++++++++ .../loaders/graph/metamath_dataset_loader.py | 33 +++ 4 files changed, 275 insertions(+), 14 deletions(-) create mode 100644 configs/dataset/graph/metamath.yaml create mode 100644 topobench/data/datasets/metamath_dataset.py create mode 100644 topobench/data/loaders/graph/metamath_dataset_loader.py diff --git a/configs/dataset/graph/metamath.yaml b/configs/dataset/graph/metamath.yaml new file mode 100644 index 000000000..be2340ae7 --- /dev/null +++ b/configs/dataset/graph/metamath.yaml @@ -0,0 +1,30 @@ +loader: + _target_: topobench.data.loaders.graph.MetamathDatasetLoader + parameters: + data_domain: graph + data_name: metamath + data_dir: ${paths.data_dir}/graph/metamath + + mask_target_node: false # already zeroed it in preprocessing + mask_mode: "zero" + + +parameters: + num_features: 768 + num_classes: 3557 + task: classification + task_level: node + loss_type: cross_entropy + monitor_metric: accuracy # large number of unevenly distributed classes; + # f1 more appropriate but unavailable + +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_splits_dir}/metamath + split_type: "fixed" # use your precomputed splits + standardize: false + +dataloader_params: + batch_size: 512 + num_workers: 0 + pin_memory: false diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 785987159..cdae08d0f 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -1,35 +1,35 @@ -"""Test pipeline for a particular dataset and model.""" +"""Test pipeline for the Metamath dataset and a simple GNN model.""" import hydra from test._utils.simplified_pipeline import run +from hydra.core.global_hydra import GlobalHydra - -DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE -MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE +# Your dataset + a simple graph model from TopoBench +DATASET = "graph/metamath" +MODELS = ["graph/gcn"] # could also try ["graph/gin"] if that config exists class TestPipeline: - """Test pipeline for a particular dataset and model.""" + """End-to-end pipeline test for Metamath.""" def setup_method(self): - """Setup method.""" - hydra.core.global_hydra.GlobalHydra.instance().clear() - + """Reset Hydra between tests.""" + GlobalHydra.instance().clear() + def test_pipeline(self): - """Test pipeline.""" - with hydra.initialize(config_path="../../configs", job_name="job"): + """Run a very short training job and ensure it completes.""" + with hydra.initialize(config_path="../../configs", job_name="metamath_test"): for MODEL in MODELS: cfg = hydra.compose( config_name="run.yaml", overrides=[ f"model={MODEL}", - f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION + f"dataset={DATASET}", "trainer.max_epochs=2", "trainer.min_epochs=1", "trainer.check_val_every_n_epoch=1", "paths=test", - "callbacks=model_checkpoint", + "callbacks=model_checkpoint", ], - return_hydra_config=True + return_hydra_config=True, ) - run(cfg) \ No newline at end of file diff --git a/topobench/data/datasets/metamath_dataset.py b/topobench/data/datasets/metamath_dataset.py new file mode 100644 index 000000000..018057850 --- /dev/null +++ b/topobench/data/datasets/metamath_dataset.py @@ -0,0 +1,198 @@ +# metamath_dataset.py + +import os +import os.path as osp +from typing import ClassVar + +import numpy as np +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.io import fs + +from topobench.data.utils import download_file_from_link + + +class MetamathDataset(InMemoryDataset): + """ + Metamath proof graph dataset backed by a precomputed data.pt file. + + The Hugging Face data.pt is expected to have the form: + + { + "data": data, # PyG Data object from collate(...) + "slices": slices, # slices dict from collate(...) + "train_idx": ..., # 1D indices of train graphs + "val_idx": ..., # 1D indices of val graphs + "test_idx": ..., # 1D indices of test graphs + } + + This class simply: + - downloads data.pt from HF (into raw_dir), + - copies it into processed_dir, + - loads it and exposes: + * self.data, self.slices + * self.split_idx = {"train", "valid", "test"} + """ + + HF_BASE: ClassVar[str] = "https://huggingface.co/datasets" + HF_REPO: ClassVar[str] = "jableable/metamath-proof-graphs" + HF_FILENAME: ClassVar[str] = "data.pt" + + def __init__(self, root: str, name: str, parameters) -> None: + self.name = name + self.parameters = parameters + + super().__init__(root) + + out = fs.torch_load(self.processed_paths[0]) + + if not isinstance(out, dict): + raise TypeError( + f"Expected dict in {self.processed_paths[0]}, got {type(out)}" + ) + + data = out["data"] + self.slices = out["slices"] + + # Rebuild Data from dict if needed + if isinstance(data, dict): + data = Data.from_dict(data) + + self.data = data + + # Expose fixed splits for TopoBench + train_idx = out.get("train_idx", None) + val_idx = out.get("val_idx", None) + test_idx = out.get("test_idx", None) + + if ( + train_idx is not None + and val_idx is not None + and test_idx is not None + ): + # Convert to numpy arrays for split_utils + if isinstance(train_idx, torch.Tensor): + train_idx = train_idx.cpu().numpy() + val_idx = val_idx.cpu().numpy() + test_idx = test_idx.cpu().numpy() + + self.split_idx = { + "train": np.array(train_idx, dtype=int), + "valid": np.array(val_idx, dtype=int), + "test": np.array(test_idx, dtype=int), + } + + # ------------------------------------------------------------------------- + # Directory layout + # ------------------------------------------------------------------------- + + @property + def raw_dir(self) -> str: + # //raw + return osp.join(self.root, self.name, "raw") + + @property + def processed_dir(self) -> str: + # //processed + return osp.join(self.root, self.name, "processed") + + # ------------------------------------------------------------------------- + # File naming + # ------------------------------------------------------------------------- + + @property + def raw_file_names(self) -> list[str]: + # We only expect a single raw artifact: data.pt + return [self.HF_FILENAME] + + @property + def processed_file_names(self) -> str: + # Single processed file, also called data.pt + return "data.pt" + + # ------------------------------------------------------------------------- + # Download from Hugging Face + # ------------------------------------------------------------------------- + + def download(self) -> None: + """ + Download data.pt from Hugging Face into raw_dir. + + Expected HF layout: + https://huggingface.co/datasets/jableable/metamath-proof-graphs/resolve/main/data/data.pt + """ + os.makedirs(self.raw_dir, exist_ok=True) + + url = f"{self.HF_BASE}/{self.HF_REPO}/resolve/main/data/{self.HF_FILENAME}" + dataset_name, file_format = os.path.splitext(self.HF_FILENAME) + file_format = file_format.lstrip(".") + + download_file_from_link( + file_link=url, + path_to_save=self.raw_dir, + dataset_name=dataset_name, + file_format=file_format, + ) + + # ------------------------------------------------------------------------- + # Process: copy / normalize the HF data.pt to processed_dir + # ------------------------------------------------------------------------- + + def process(self) -> None: + """Load raw data.pt, fix dtypes, and save processed data.pt as a dict.""" + raw_pt = osp.join(self.raw_dir, "data.pt") + obj = torch.load(raw_pt, weights_only=False) + + raw_data = obj["data"] + raw_slices = obj["slices"] + train_idx = obj["train_idx"] + val_idx = obj["val_idx"] + test_idx = obj["test_idx"] + + # Temporary dataset to reconstruct individual graphs + class _Tmp(InMemoryDataset): + def __init__(self, data, slices): + super().__init__(".") + self.data = data + self.slices = slices + + def _download(self): + pass + + def _process(self): + pass + + tmp = _Tmp(raw_data, raw_slices) + + graphs = [] + for i in range(len(tmp)): + g = tmp[i] + + # 🔧 Critical fix: ensure edge_index is integer + if hasattr(g, "edge_index"): + g.edge_index = g.edge_index.long() + + graphs.append(g) + + # Re-collate into a clean storage + data_fixed, slices_fixed = tmp.collate(graphs) + + out = { + "data": data_fixed, + "slices": slices_fixed, + "train_idx": train_idx, + "val_idx": val_idx, + "test_idx": test_idx, + } + + fs.torch_save(out, self.processed_paths[0]) + + # ------------------------------------------------------------------------- + + def __repr__(self) -> str: + return ( + f"MetamathDataset(root={self.root}, name={self.name}, " + f"num_graphs={len(self)}, " + f"has_split_idx={'split_idx' in self.__dict__})" + ) diff --git a/topobench/data/loaders/graph/metamath_dataset_loader.py b/topobench/data/loaders/graph/metamath_dataset_loader.py new file mode 100644 index 000000000..a7e1c1630 --- /dev/null +++ b/topobench/data/loaders/graph/metamath_dataset_loader.py @@ -0,0 +1,33 @@ +"""Loader for Metamath proof graph dataset.""" + +from pathlib import Path +from omegaconf import DictConfig +from torch_geometric.data import Dataset + +from topobench.data.datasets import MetamathDataset +from topobench.data.loaders.base import AbstractLoader + + +class MetamathDatasetLoader(AbstractLoader): + """Thin wrapper around MetamathDataset for TopoBench.""" + + def __init__(self, parameters: DictConfig) -> None: + super().__init__(parameters) + + def load_dataset(self) -> Dataset: + """ + Initialize the MetamathDataset and expose its processed_dir + via self.data_dir (for split utils / logging). + """ + dataset = MetamathDataset( + root=str(self.root_data_dir), + name=self.parameters.data_name, + parameters=self.parameters, + ) + + # Point data_dir to processed folder for downstream utilities + self.data_dir = Path(dataset.processed_dir) + + # No label collapsing or masking here; dataset is ready to use. + # Splits are handled via dataset.split_idx + split_utils (fixed split). + return dataset From 7ef598fb4cbe8cec5cf3f164f31498a91add8e22 Mon Sep 17 00:00:00 2001 From: jableable Date: Tue, 25 Nov 2025 11:18:53 -0600 Subject: [PATCH 2/2] Fix lint issues --- topobench/data/datasets/metamath_dataset.py | 1 - topobench/data/loaders/graph/metamath_dataset_loader.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/topobench/data/datasets/metamath_dataset.py b/topobench/data/datasets/metamath_dataset.py index 018057850..a9a1ed2a5 100644 --- a/topobench/data/datasets/metamath_dataset.py +++ b/topobench/data/datasets/metamath_dataset.py @@ -6,7 +6,6 @@ import numpy as np import torch -from omegaconf import DictConfig from torch_geometric.data import Data, InMemoryDataset from torch_geometric.io import fs diff --git a/topobench/data/loaders/graph/metamath_dataset_loader.py b/topobench/data/loaders/graph/metamath_dataset_loader.py index a7e1c1630..671619986 100644 --- a/topobench/data/loaders/graph/metamath_dataset_loader.py +++ b/topobench/data/loaders/graph/metamath_dataset_loader.py @@ -1,6 +1,7 @@ """Loader for Metamath proof graph dataset.""" from pathlib import Path + from omegaconf import DictConfig from torch_geometric.data import Dataset