30 changes: 30 additions & 0 deletions configs/dataset/graph/metamath.yaml
@@ -0,0 +1,30 @@
loader:
  _target_: topobench.data.loaders.graph.MetamathDatasetLoader
  parameters:
    data_domain: graph
    data_name: metamath
    data_dir: ${paths.data_dir}/graph/metamath

    mask_target_node: false # target node already zeroed in preprocessing
    mask_mode: "zero"

parameters:
  num_features: 768
  num_classes: 3557
  task: classification
  task_level: node
  loss_type: cross_entropy
  monitor_metric: accuracy # many unevenly distributed classes; F1 would be more appropriate but is unavailable

split_params:
  learning_setting: inductive
  data_split_dir: ${paths.data_splits_dir}/metamath
  split_type: "fixed" # use the precomputed splits
  standardize: false

dataloader_params:
  batch_size: 512
  num_workers: 0
  pin_memory: false
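
For a quick sanity check that this config resolves, here is a minimal sketch (not part of the PR; run from the repo root, assuming the standard TopoBench `configs/` layout and that the dataset group mounts at `cfg.dataset`, as in the test below):

```python
# Hedged sketch: compose the run config with the new metamath dataset entry
# and inspect a couple of resolved values.
import hydra
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()  # safe even if Hydra was never initialized
with hydra.initialize(config_path="configs", job_name="metamath_check"):
    cfg = hydra.compose(
        config_name="run.yaml",
        overrides=["dataset=graph/metamath", "model=graph/gcn"],
    )
    print(cfg.dataset.parameters.num_classes)   # expected: 3557
    print(cfg.dataset.split_params.split_type)  # expected: "fixed"
```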
28 changes: 14 additions & 14 deletions test/pipeline/test_pipeline.py
@@ -1,35 +1,35 @@
"""Test pipeline for a particular dataset and model."""
"""Test pipeline for the Metamath dataset and a simple GNN model."""

import hydra
from test._utils.simplified_pipeline import run
from hydra.core.global_hydra import GlobalHydra


DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE
MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE
# Your dataset + a simple graph model from TopoBench
DATASET = "graph/metamath"
MODELS = ["graph/gcn"] # could also try ["graph/gin"] if that config exists


class TestPipeline:
"""Test pipeline for a particular dataset and model."""
"""End-to-end pipeline test for Metamath."""

def setup_method(self):
"""Setup method."""
hydra.core.global_hydra.GlobalHydra.instance().clear()
"""Reset Hydra between tests."""
GlobalHydra.instance().clear()

def test_pipeline(self):
"""Test pipeline."""
with hydra.initialize(config_path="../../configs", job_name="job"):
"""Run a very short training job and ensure it completes."""
with hydra.initialize(config_path="../../configs", job_name="metamath_test"):
for MODEL in MODELS:
cfg = hydra.compose(
config_name="run.yaml",
overrides=[
f"model={MODEL}",
f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION
f"dataset={DATASET}",
"trainer.max_epochs=2",
"trainer.min_epochs=1",
"trainer.check_val_every_n_epoch=1",
"paths=test",
"callbacks=model_checkpoint",
"callbacks=model_checkpoint",
],
return_hydra_config=True
return_hydra_config=True,
)
run(cfg)
197 changes: 197 additions & 0 deletions topobench/data/datasets/metamath_dataset.py
@@ -0,0 +1,197 @@
"""Metamath proof graph dataset for TopoBench."""

import os
import os.path as osp
from typing import ClassVar

import numpy as np
import torch
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs

from topobench.data.utils import download_file_from_link


class MetamathDataset(InMemoryDataset):
"""
Metamath proof graph dataset backed by a precomputed data.pt file.

The Hugging Face data.pt is expected to have the form:

{
"data": data, # PyG Data object from collate(...)
"slices": slices, # slices dict from collate(...)
"train_idx": ..., # 1D indices of train graphs
"val_idx": ..., # 1D indices of val graphs
"test_idx": ..., # 1D indices of test graphs
}

    This class:
      - downloads data.pt from HF (into raw_dir),
      - normalizes it (casts edge_index to long) and saves it to processed_dir,
      - loads it and exposes:
        * self.data, self.slices
        * self.split_idx = {"train", "valid", "test"}
    """

HF_BASE: ClassVar[str] = "https://huggingface.co/datasets"
HF_REPO: ClassVar[str] = "jableable/metamath-proof-graphs"
HF_FILENAME: ClassVar[str] = "data.pt"

def __init__(self, root: str, name: str, parameters) -> None:
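        """Initialize the dataset; downloads and processes data.pt on first use."""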
self.name = name
self.parameters = parameters

super().__init__(root)

out = fs.torch_load(self.processed_paths[0])

if not isinstance(out, dict):
raise TypeError(
f"Expected dict in {self.processed_paths[0]}, got {type(out)}"
)

data = out["data"]
self.slices = out["slices"]

# Rebuild Data from dict if needed
if isinstance(data, dict):
data = Data.from_dict(data)

self.data = data

# Expose fixed splits for TopoBench
train_idx = out.get("train_idx", None)
val_idx = out.get("val_idx", None)
test_idx = out.get("test_idx", None)

if (
train_idx is not None
and val_idx is not None
and test_idx is not None
):
# Convert to numpy arrays for split_utils
if isinstance(train_idx, torch.Tensor):
train_idx = train_idx.cpu().numpy()
val_idx = val_idx.cpu().numpy()
test_idx = test_idx.cpu().numpy()

self.split_idx = {
"train": np.array(train_idx, dtype=int),
"valid": np.array(val_idx, dtype=int),
"test": np.array(test_idx, dtype=int),
}

# -------------------------------------------------------------------------
# Directory layout
# -------------------------------------------------------------------------

@property
def raw_dir(self) -> str:
# <root>/<name>/raw
return osp.join(self.root, self.name, "raw")

@property
def processed_dir(self) -> str:
# <root>/<name>/processed
return osp.join(self.root, self.name, "processed")

# -------------------------------------------------------------------------
# File naming
# -------------------------------------------------------------------------

@property
def raw_file_names(self) -> list[str]:
# We only expect a single raw artifact: data.pt
return [self.HF_FILENAME]

@property
def processed_file_names(self) -> str:
# Single processed file, also called data.pt
return "data.pt"

# -------------------------------------------------------------------------
# Download from Hugging Face
# -------------------------------------------------------------------------

def download(self) -> None:
"""
Download data.pt from Hugging Face into raw_dir.

Expected HF layout:
https://huggingface.co/datasets/jableable/metamath-proof-graphs/resolve/main/data/data.pt
"""
os.makedirs(self.raw_dir, exist_ok=True)

url = f"{self.HF_BASE}/{self.HF_REPO}/resolve/main/data/{self.HF_FILENAME}"
dataset_name, file_format = os.path.splitext(self.HF_FILENAME)
file_format = file_format.lstrip(".")

download_file_from_link(
file_link=url,
path_to_save=self.raw_dir,
dataset_name=dataset_name,
file_format=file_format,
)

# -------------------------------------------------------------------------
# Process: copy / normalize the HF data.pt to processed_dir
# -------------------------------------------------------------------------

def process(self) -> None:
"""Load raw data.pt, fix dtypes, and save processed data.pt as a dict."""
raw_pt = osp.join(self.raw_dir, "data.pt")
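        # weights_only=False is needed: the file pickles a dict of Data/slices, not bare tensors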
obj = torch.load(raw_pt, weights_only=False)

raw_data = obj["data"]
raw_slices = obj["slices"]
train_idx = obj["train_idx"]
val_idx = obj["val_idx"]
test_idx = obj["test_idx"]

# Temporary dataset to reconstruct individual graphs
class _Tmp(InMemoryDataset):
def __init__(self, data, slices):
super().__init__(".")
self.data = data
self.slices = slices

def _download(self):
pass

def _process(self):
pass

tmp = _Tmp(raw_data, raw_slices)

graphs = []
for i in range(len(tmp)):
g = tmp[i]

# 🔧 Critical fix: ensure edge_index is integer
if hasattr(g, "edge_index"):
g.edge_index = g.edge_index.long()

graphs.append(g)

# Re-collate into a clean storage
data_fixed, slices_fixed = tmp.collate(graphs)

out = {
"data": data_fixed,
"slices": slices_fixed,
"train_idx": train_idx,
"val_idx": val_idx,
"test_idx": test_idx,
}

fs.torch_save(out, self.processed_paths[0])

# -------------------------------------------------------------------------

def __repr__(self) -> str:
return (
f"MetamathDataset(root={self.root}, name={self.name}, "
f"num_graphs={len(self)}, "
f"has_split_idx={'split_idx' in self.__dict__})"
)
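
A hedged usage sketch for the class above (not part of the PR; the first instantiation downloads `data.pt` from Hugging Face, so it needs network access, and since the `parameters` argument is only stored, a plain `DictConfig` suffices):

```python
# Hedged sketch: instantiate the dataset directly and inspect the
# fixed splits and one reconstructed graph.
from omegaconf import OmegaConf

from topobench.data.datasets import MetamathDataset

params = OmegaConf.create({"data_domain": "graph", "data_name": "metamath"})
dataset = MetamathDataset(root="datasets/graph", name="metamath", parameters=params)

print(dataset)  # repr with num_graphs / has_split_idx
print({k: v.shape for k, v in dataset.split_idx.items()})  # train/valid/test sizes
g = dataset[0]
print(g.edge_index.dtype)  # torch.int64, thanks to the fix in process()
```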
34 changes: 34 additions & 0 deletions topobench/data/loaders/graph/metamath_dataset_loader.py
@@ -0,0 +1,34 @@
"""Loader for Metamath proof graph dataset."""

from pathlib import Path

from omegaconf import DictConfig
from torch_geometric.data import Dataset

from topobench.data.datasets import MetamathDataset
from topobench.data.loaders.base import AbstractLoader


class MetamathDatasetLoader(AbstractLoader):
"""Thin wrapper around MetamathDataset for TopoBench."""

def __init__(self, parameters: DictConfig) -> None:
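        """Store the loader parameters (delegates to AbstractLoader)."""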
super().__init__(parameters)

def load_dataset(self) -> Dataset:
"""
Initialize the MetamathDataset and expose its processed_dir
via self.data_dir (for split utils / logging).
"""
dataset = MetamathDataset(
root=str(self.root_data_dir),
name=self.parameters.data_name,
parameters=self.parameters,
)

# Point data_dir to processed folder for downstream utilities
self.data_dir = Path(dataset.processed_dir)

# No label collapsing or masking here; dataset is ready to use.
# Splits are handled via dataset.split_idx + split_utils (fixed split).
return dataset
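
Finally, a hedged sketch of driving the loader without Hydra (not part of the PR; assumes `AbstractLoader` derives its root directory from `parameters.data_dir`, which the YAML config above supplies; the path is illustrative):

```python
# Hedged sketch: use the loader standalone, mirroring what Hydra's
# _target_ instantiation does with configs/dataset/graph/metamath.yaml.
from omegaconf import OmegaConf

from topobench.data.loaders.graph import MetamathDatasetLoader

params = OmegaConf.create(
    {
        "data_domain": "graph",
        "data_name": "metamath",
        "data_dir": "datasets/graph/metamath",  # illustrative path
    }
)
loader = MetamathDatasetLoader(params)
dataset = loader.load_dataset()
print(loader.data_dir)  # points at <root>/metamath/processed
```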