60 changes: 60 additions & 0 deletions test/data/core/test_qm9_robust_on_disk_dataset.py
@@ -0,0 +1,60 @@
import os
import torch
import pytest
from torch_geometric.data import Data
from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset


@pytest.fixture
def tmp_dataset(tmp_path, monkeypatch):
    """
    Fixture creating 20 synthetic samples and monkeypatching QM9.
    """

    # Fake QM9 that writes one torch-saved (5, 11) tensor per raw file.
    class FakeQM9:
        raw_file_names = [f"mol_{i}.xyz" for i in range(20)]

        def __init__(self, root):
            self.root = root
            os.makedirs(os.path.join(root, "raw"), exist_ok=True)
            self.raw_paths = []
            for name in self.raw_file_names:
                path = os.path.join(root, "raw", name)
                tensor = torch.rand(5, 11)
                torch.save(tensor, path)
                self.raw_paths.append(path)

    # Monkeypatch QM9 so no real download is triggered. (The fake does not
    # need processed_paths; the dataset manages its own processed files.)
    monkeypatch.setattr(
        "topobench.data.core.qm9_robust_on_disk_dataset.QM9",
        lambda root: FakeQM9(root),
    )

    dataset = QM9RobustOnDiskDataset(root=str(tmp_path / "qm9"), chunk_size=5)
    return dataset


def test_dataset_length(tmp_dataset):
    assert tmp_dataset.len() == 20


def test_dataset_get(tmp_dataset):
    for i in range(20):
        data = tmp_dataset.get(i)
        assert isinstance(data, Data)
        assert data.x.shape == (5, 11)


def test_chunks_are_saved(tmp_dataset):
    chunk_files = [
        f for f in os.listdir(tmp_dataset.processed_dir)
        if f.startswith("chunk_") and f.endswith(".pt")
    ]
    assert len(chunk_files) == 4  # 20 samples / chunk_size=5


def test_index_out_of_range(tmp_dataset):
    with pytest.raises(IndexError):
        tmp_dataset.get(999)
52 changes: 52 additions & 0 deletions test/data/core/test_qm9_robust_on_disk_dataset_1000.py
@@ -0,0 +1,52 @@
import os
import torch
import pytest
from torch_geometric.data import Data
from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset


@pytest.fixture
def tmp_dataset_1000(tmp_path, monkeypatch):
    """
    Fixture creating 1000 synthetic samples with chunk_size=50.
    """

    class FakeQM9:
        raw_file_names = [f"mol_{i}.xyz" for i in range(1000)]

        def __init__(self, root):
            self.root = root
            os.makedirs(os.path.join(root, "raw"), exist_ok=True)
            self.raw_paths = []
            for name in self.raw_file_names:
                path = os.path.join(root, "raw", name)
                tensor = torch.rand(5, 11)
                torch.save(tensor, path)
                self.raw_paths.append(path)

    monkeypatch.setattr(
        "topobench.data.core.qm9_robust_on_disk_dataset.QM9",
        lambda root: FakeQM9(root),
    )

    dataset = QM9RobustOnDiskDataset(root=str(tmp_path / "qm9_1000"), chunk_size=50)
    return dataset


def test_dataset_length_1000(tmp_dataset_1000):
    assert tmp_dataset_1000.len() == 1000


def test_chunks_are_saved_1000(tmp_dataset_1000):
    chunk_files = [
        f for f in os.listdir(tmp_dataset_1000.processed_dir)
        if f.startswith("chunk_") and f.endswith(".pt")
    ]
    # 1000 / 50 = 20 chunks
    assert len(chunk_files) == 20


def test_dataset_get_1000(tmp_dataset_1000):
    for i in range(0, 1000, 100):  # test every 100th sample for speed
        data = tmp_dataset_1000.get(i)
        assert isinstance(data, Data)
        assert data.x.shape == (5, 11)
52 changes: 52 additions & 0 deletions test/data/core/test_qm9_robust_on_disk_dataset_10000.py
@@ -0,0 +1,52 @@
import os
import torch
import pytest
from torch_geometric.data import Data
from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset


@pytest.fixture
def tmp_dataset_10000(tmp_path, monkeypatch):
    """
    Fixture creating 10000 synthetic samples with chunk_size=500.
    """

    class FakeQM9:
        raw_file_names = [f"mol_{i}.xyz" for i in range(10000)]

        def __init__(self, root):
            self.root = root
            os.makedirs(os.path.join(root, "raw"), exist_ok=True)
            self.raw_paths = []
            for name in self.raw_file_names:
                path = os.path.join(root, "raw", name)
                tensor = torch.rand(5, 11)
                torch.save(tensor, path)
                self.raw_paths.append(path)

    monkeypatch.setattr(
        "topobench.data.core.qm9_robust_on_disk_dataset.QM9",
        lambda root: FakeQM9(root),
    )

    dataset = QM9RobustOnDiskDataset(root=str(tmp_path / "qm9_10000"), chunk_size=500)
    return dataset


def test_dataset_length_10000(tmp_dataset_10000):
    assert tmp_dataset_10000.len() == 10000


def test_chunks_are_saved_10000(tmp_dataset_10000):
    chunk_files = [
        f for f in os.listdir(tmp_dataset_10000.processed_dir)
        if f.startswith("chunk_") and f.endswith(".pt")
    ]
    # 10000 / 500 = 20 chunks
    assert len(chunk_files) == 20


def test_dataset_get_10000(tmp_dataset_10000):
    for i in range(0, 10000, 1000):  # test every 1000th sample for speed
        data = tmp_dataset_10000.get(i)
        assert isinstance(data, Data)
        assert data.x.shape == (5, 11)
52 changes: 52 additions & 0 deletions test/data/core/test_qm9_robust_on_disk_dataset_100000.py
@@ -0,0 +1,52 @@
import os
import torch
import pytest
from torch_geometric.data import Data
from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset


@pytest.fixture
def tmp_dataset_100000(tmp_path, monkeypatch):
    """
    Fixture creating 100000 synthetic samples with chunk_size=5000.
    """

    class FakeQM9:
        raw_file_names = [f"mol_{i}.xyz" for i in range(100000)]

        def __init__(self, root):
            self.root = root
            os.makedirs(os.path.join(root, "raw"), exist_ok=True)
            self.raw_paths = []
            for name in self.raw_file_names:
                path = os.path.join(root, "raw", name)
                tensor = torch.rand(5, 11)
                torch.save(tensor, path)
                self.raw_paths.append(path)

    monkeypatch.setattr(
        "topobench.data.core.qm9_robust_on_disk_dataset.QM9",
        lambda root: FakeQM9(root),
    )

    dataset = QM9RobustOnDiskDataset(root=str(tmp_path / "qm9_100000"), chunk_size=5000)
    return dataset


def test_dataset_length_100000(tmp_dataset_100000):
    assert tmp_dataset_100000.len() == 100000


def test_chunks_are_saved_100000(tmp_dataset_100000):
    chunk_files = [
        f for f in os.listdir(tmp_dataset_100000.processed_dir)
        if f.startswith("chunk_") and f.endswith(".pt")
    ]
    # 100000 / 5000 = 20 chunks
    assert len(chunk_files) == 20


def test_dataset_get_100000(tmp_dataset_100000):
    for i in range(0, 100000, 10000):  # test every 10000th sample
        data = tmp_dataset_100000.get(i)
        assert isinstance(data, Data)
        assert data.x.shape == (5, 11)
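
Note: the four test modules above differ only in sample count and chunk size. As a possible consolidation (a sketch, not part of this diff; `sized_dataset` and `test_chunk_count` are hypothetical names), a single parametrized fixture could cover all sizes in one module:

import os

import pytest
import torch

from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset


@pytest.fixture(params=[(20, 5), (1000, 50), (10000, 500), (100000, 5000)])
def sized_dataset(request, tmp_path, monkeypatch):
    n_samples, chunk_size = request.param

    class FakeQM9:
        raw_file_names = [f"mol_{i}.xyz" for i in range(n_samples)]

        def __init__(self, root):
            os.makedirs(os.path.join(root, "raw"), exist_ok=True)
            for name in self.raw_file_names:
                torch.save(torch.rand(5, 11), os.path.join(root, "raw", name))

    monkeypatch.setattr(
        "topobench.data.core.qm9_robust_on_disk_dataset.QM9",
        lambda root: FakeQM9(root),
    )
    dataset = QM9RobustOnDiskDataset(root=str(tmp_path / "qm9"), chunk_size=chunk_size)
    return dataset, n_samples, chunk_size


def test_chunk_count(sized_dataset):
    dataset, n_samples, chunk_size = sized_dataset
    chunk_files = [
        f for f in os.listdir(dataset.processed_dir)
        if f.startswith("chunk_") and f.endswith(".pt")
    ]
    assert len(chunk_files) == n_samples // chunk_size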
32 changes: 32 additions & 0 deletions topobench/data/core/qm9_robust_on_disk_dataset.py
@@ -0,0 +1,32 @@
import os

import torch
from torch_geometric.data import Data
from torch_geometric.datasets import QM9

from .robust_on_disk_dataset import RobustOnDiskDataset


class QM9RobustOnDiskDataset(RobustOnDiskDataset):
    """
    QM9 dataset built on RobustOnDiskDataset.
    Keeps the QM9 raw-file layout and builds processed chunks from the raw files.
    """

    def __init__(self, root: str, chunk_size: int = 512, **kwargs):
        # Instantiating QM9 ensures the raw files exist (downloading if needed).
        self.qm9 = QM9(root=root)
        super().__init__(root=root, chunk_size=chunk_size, **kwargs)

    @property
    def raw_file_names(self):
        return self.qm9.raw_file_names

    def prepare_raw_data(self):
        return [os.path.join(self.root, "raw", f) for f in self.raw_file_names]

    def process_raw_file(self, raw_path: str):
        # Each raw file holds a torch-saved object; wrap it in Data if needed.
        # weights_only=False is required on PyTorch >= 2.6 to unpickle
        # arbitrary (non-tensor) objects.
        data = torch.load(raw_path, weights_only=False)
        if not isinstance(data, Data):
            data = Data(x=data)
        return [data]
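
For orientation, a minimal usage sketch (assuming the QM9 raw files are already available under `root`; the loader is the standard `torch_geometric.loader.DataLoader`):

from torch_geometric.loader import DataLoader

from topobench.data.core.qm9_robust_on_disk_dataset import QM9RobustOnDiskDataset

# Chunks are written once on first instantiation; later runs reuse them.
dataset = QM9RobustOnDiskDataset(root="data/qm9", chunk_size=512)
print(len(dataset))  # delegates to Dataset.len()

loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))  # a torch_geometric Batch built from chunked Data samples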
102 changes: 102 additions & 0 deletions topobench/data/core/robust_on_disk_dataset.py
@@ -0,0 +1,102 @@
import os
from collections.abc import Callable

import torch
from torch_geometric.data import Data, Dataset


class RobustOnDiskDataset(Dataset):
    """
    A robust, chunk-based on-disk dataset that processes raw samples
    individually and saves them in chunks to avoid memory bottlenecks.
    """

    def __init__(
        self,
        root: str,
        chunk_size: int = 512,
        transform: Callable | None = None,
        pre_transform: Callable | None = None,
        pre_filter: Callable | None = None,
    ):
        # Must be set before super().__init__, which may itself trigger process().
        self.chunk_size = chunk_size
        super().__init__(root, transform, pre_transform, pre_filter)

        # Fallback in case the parent class skipped processing.
        if not self._is_processed():
            self.process()

    @property
    def processed_dir(self) -> str:
        return os.path.join(self.root, "processed")

    @property
    def processed_file_names(self) -> list[str]:
        if not os.path.exists(self.processed_dir):
            return []
        files = [
            f
            for f in os.listdir(self.processed_dir)
            if f.startswith("chunk_") and f.endswith(".pt")
        ]
        return sorted(files)

    def _is_processed(self) -> bool:
        return len(self.processed_file_names) > 0

    def len(self) -> int:
        # Sum of all samples in chunks. Note: this loads every chunk, so it
        # is O(dataset size); cache the result if it is called often.
        total = 0
        for fname in self.processed_file_names:
            chunk = torch.load(
                os.path.join(self.processed_dir, fname), weights_only=False
            )
            total += len(chunk)
        return total

    def get(self, idx: int) -> Data:
        if idx < 0 or idx >= self.len():
            raise IndexError(f"Index {idx} out of range")
        # All chunks except possibly the last hold exactly chunk_size samples,
        # so the chunk id and in-chunk offset follow from integer division.
        chunk_id = idx // self.chunk_size
        within = idx % self.chunk_size
        chunk_path = os.path.join(self.processed_dir, f"chunk_{chunk_id}.pt")
        # weights_only=False: chunks are pickled lists of Data objects.
        chunk = torch.load(chunk_path, weights_only=False)
        data = chunk[within]
        if self.transform:
            data = self.transform(data)
        return data

    # -------------------------
    # To be implemented by subclass
    # -------------------------
    def prepare_raw_data(self) -> list[str]:
        raise NotImplementedError

    def process_raw_file(self, raw_path: str) -> list[Data]:
        raise NotImplementedError

    # -------------------------
    # Main processing logic
    # -------------------------
    def process(self):
        os.makedirs(self.processed_dir, exist_ok=True)
        raw_paths = self.prepare_raw_data()
        all_data = []

        for raw_path in raw_paths:
            samples = self.process_raw_file(raw_path)
            for data in samples:
                if self.pre_filter and not self.pre_filter(data):
                    continue
                if self.pre_transform:
                    data = self.pre_transform(data)
                all_data.append(data)

            # Flush full chunks as soon as they are available, so roughly at
            # most chunk_size samples plus one raw file stay in memory.
            while len(all_data) >= self.chunk_size:
                self._save_chunk(all_data[: self.chunk_size])
                all_data = all_data[self.chunk_size:]

        # Save the (possibly smaller) final chunk.
        if len(all_data) > 0:
            self._save_chunk(all_data)

    def _save_chunk(self, chunk: list[Data]):
        # Chunk ids are assigned sequentially from the count of chunks
        # already on disk.
        chunk_id = len(self.processed_file_names)
        path = os.path.join(self.processed_dir, f"chunk_{chunk_id}.pt")
        torch.save(chunk, path)
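
To illustrate the two subclass hooks, a minimal sketch (the TensorDirDataset name and its raw-file layout are hypothetical, and the raw files are assumed to already sit under <root>/raw):

import os

import torch
from torch_geometric.data import Data

from topobench.data.core.robust_on_disk_dataset import RobustOnDiskDataset


class TensorDirDataset(RobustOnDiskDataset):
    """Hypothetical subclass: every raw file is a torch-saved feature tensor."""

    @property
    def raw_file_names(self):
        # Assumes the raw directory exists and is pre-populated.
        return sorted(os.listdir(os.path.join(self.root, "raw")))

    def prepare_raw_data(self):
        raw_dir = os.path.join(self.root, "raw")
        return [os.path.join(raw_dir, f) for f in self.raw_file_names]

    def process_raw_file(self, raw_path: str):
        x = torch.load(raw_path, weights_only=False)
        return [Data(x=x)]  # one sample per raw file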