diff --git a/configs/dataset/hypergraph/aicircuit.yaml b/configs/dataset/hypergraph/aicircuit.yaml
new file mode 100644
index 000000000..f64a2cb2e
--- /dev/null
+++ b/configs/dataset/hypergraph/aicircuit.yaml
@@ -0,0 +1,32 @@
+# Dataset loader config
+loader:
+  _target_: topobench.data.loaders.hypergraph.aicircuit_dataset_loader.AICircuitDatasetLoader
+  parameters:
+    data_domain: hypergraph
+    data_type: analog_circuit
+    data_name: AICircuit
+    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
+
+# Dataset parameters
+parameters:
+  num_features: 1
+  num_classes: 3
+  task: regression
+  loss_type: mse
+  monitor_metric: mae
+  task_level: graph
+
+# Splits
+split_params:
+  learning_setting: inductive
+  data_seed: 0
+  split_type: random
+  train_prop: 0.8
+  standardize: True
+  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
+
+# Dataloader parameters
+dataloader_params:
+  batch_size: 32
+  num_workers: 0
+  pin_memory: False
diff --git a/configs/dataset/hypergraph/analoggenie.yaml b/configs/dataset/hypergraph/analoggenie.yaml
new file mode 100644
index 000000000..773e381a0
--- /dev/null
+++ b/configs/dataset/hypergraph/analoggenie.yaml
@@ -0,0 +1,32 @@
+# Dataset loader config
+loader:
+  _target_: topobench.data.loaders.hypergraph.analoggenie_dataset_loader.AnalogGenieDatasetLoader
+  parameters:
+    data_domain: hypergraph
+    data_type: analog_circuit
+    data_name: AnalogGenie
+    data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type}
+
+# Dataset parameters
+parameters:
+  num_features: 1
+  num_classes: 1
+  task: regression
+  loss_type: mse
+  monitor_metric: mae
+  task_level: graph
+
+# Splits
+split_params:
+  learning_setting: inductive
+  data_seed: 0
+  split_type: random
+  train_prop: 0.8
+  standardize: False
+  data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name}
+
+# Dataloader parameters
+dataloader_params:
+  batch_size: 32
+  num_workers: 0
+  pin_memory: False
diff --git a/configs/experiment/aicircuit_analog.yaml b/configs/experiment/aicircuit_analog.yaml
new file mode 100644
index 000000000..3e6e64c9f
--- /dev/null
+++ b/configs/experiment/aicircuit_analog.yaml
@@ -0,0 +1,23 @@
+
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=aicircuit_analog
+
+defaults:
+  - override /dataset: hypergraph/aicircuit
+  - override /model: hypergraph/unignn
+  - override /callbacks: default
+  - override /trainer: default
+
+tags: ["unignn", "aicircuit_analog"]
+
+seed: 42
+
+trainer:
+  min_epochs: 50
+  max_epochs: 100
+  gradient_clip_val: 0.1
+
+model:
+  compile: false
diff --git a/configs/experiment/analog_genie.yaml b/configs/experiment/analog_genie.yaml
new file mode 100644
index 000000000..2883a6ac4
--- /dev/null
+++ b/configs/experiment/analog_genie.yaml
@@ -0,0 +1,23 @@
+
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=analog_genie
+
+defaults:
+  - override /dataset: hypergraph/analoggenie
+  - override /model: hypergraph/unignn2
+  - override /callbacks: default
+  - override /trainer: default
+
+tags: ["unignn2", "analog_genie"]
+
+seed: 42
+
+trainer:
+  min_epochs: 50
+  max_epochs: 100
+  gradient_clip_val: 0.1
+
+model:
+  compile: false
diff --git a/configs/hydra/hydra_logging/colorlog.yaml b/configs/hydra/hydra_logging/colorlog.yaml
new file mode 100644
index 000000000..52e33fbf2
--- /dev/null
+++ b/configs/hydra/hydra_logging/colorlog.yaml
@@ -0,0 +1,15 @@
+# @package hydra.hydra_logging
+
+version: 1
+formatters:
+  colorlog:
+    (): "colorlog.ColoredFormatter"
+    format: "[%(levelname)s] %(name)s - %(message)s"
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: colorlog
+    stream: ext://sys.stdout
+root:
+  handlers: [console]
+  level: INFO
\ No newline at end of file
diff --git a/configs/hydra/job_logging/colorlog.yaml b/configs/hydra/job_logging/colorlog.yaml
new file mode 100644
index 000000000..e69de29bb
diff --git a/configs/hydra/test_default.yaml b/configs/hydra/test_default.yaml
new file mode 100644
index 000000000..ff970302d
--- /dev/null
+++ b/configs/hydra/test_default.yaml
@@ -0,0 +1,19 @@
+# https://hydra.cc/docs/configure_hydra/intro/
+
+# disable color logging for tests
+defaults:
+  - override hydra_logging: disabled
+  - override job_logging: disabled
+
+# output directory, generated dynamically on each run
+run:
+  dir: ${paths.log_dir}/${task_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S}
+sweep:
+  dir: ${paths.log_dir}/${task_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S}
+  subdir: ${hydra.job.num}
+
+job_logging:
+  handlers:
+    file:
+      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+      filename: ${hydra.runtime.output_dir}/${task_name}.log
\ No newline at end of file
diff --git a/docs/datasets/analog_circuits.rst b/docs/datasets/analog_circuits.rst
new file mode 100644
index 000000000..76b3d27ea
--- /dev/null
+++ b/docs/datasets/analog_circuits.rst
@@ -0,0 +1,92 @@
+
+.. _analog-circuits-datasets:
+
+************************
+Analog Circuit Datasets
+************************
+
+This page documents, in depth, how the analog circuit datasets **AICircuit** and **AnalogGenie** are represented and processed inside TopoBench. Both are converted from SPICE-like netlists to **hypergraphs**, where:
+
+- **Nodes** = circuit nets (wires)
+- **Hyperedges** = devices (components)
+- **Incidence roles** = pin-level roles (drain/gate/source/bulk, collector/base/emitter, …)
+- **Hyperedge params** = parsed numeric parameters (best-effort)
+
+High-level summary
+==================
+
+- **AICircuit** (https://arxiv.org/abs/2407.18272): graph-level regression
+
+  - Has graph attributes (design parameters) and targets (performance metrics).
+  - Source: `Dataset/<circuit_type>/<circuit_type>.csv` + `Simulation/Netlists/<circuit_type>/netlist`.
+
+- **AnalogGenie** (https://github.com/xz-group/AnalogGenie, https://arxiv.org/abs/2503.00205): unsupervised
+
+  - No labels; large collection of `.cir` netlists.
+  - Source: `Dataset/<circuit_id>/<circuit_id>.cir` across the full GitHub repo (~493 MB).
+
+Download sources
+================
+
+- AICircuit: GitHub zip `https://github.com/AvestimehrResearchGroup/AICircuit/archive/refs/heads/main.zip`
+- AnalogGenie: GitHub repo `https://github.com/xz-group/AnalogGenie` (full clone used)
+
+SPICE parsing and subcircuit flattening
+=======================================
+
+Both datasets share the same parsing logic:
+
+1. **Subcircuits**: `.subckt ... .ends` blocks are stored; any `X*` instance is recursively expanded (pins are mapped to parent nets, nested subcircuits are flattened).
+2. **Line tokenization** (simplified SPICE grammar):
+
+   - MOS (M*, mos4…): `name drain gate source bulk [model] [params…]`
+   - BJT (Q*, bjt/npn/pnp): `name collector base emitter [model] [params…]`
+   - Other 2-terminal (R/L/C/V/I…): `name node1 node2 [type/model] [params…]`
+   - Continuation lines starting with `+` are appended to the previous line.
+   - Comments (`*`, `//`) and empty lines are skipped.
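+
+As an illustration, the comment skipping and continuation-line merging above can be sketched as follows (a minimal, self-contained rendition of the preprocessing step inside the datasets' internal `parse_spice_netlist` helper; the sample netlist lines are made up for the example):
+
+.. code-block:: python
+
+    def merge_spice_lines(lines):
+        """Skip comments/blank lines and fold '+' continuations into one line."""
+        merged, buffer = [], ""
+        for line in lines:
+            line = line.strip()
+            if not line or line.startswith(("*", "//")):
+                continue  # comment or empty line
+            if line.startswith("+"):
+                buffer += " " + line[1:].strip()  # continuation of previous line
+                continue
+            if buffer:
+                merged.append(buffer)
+            buffer = line
+        if buffer:
+            merged.append(buffer)
+        return merged
+
+    netlist = [
+        "* bias branch",
+        "M0 (IOUT1 net4 VSS VSS) nmos4",
+        "+ W=2.5 L=0.18",
+        "R0 (VDD net4) resistor",
+    ]
+    print(merge_spice_lines(netlist))
+    # ['M0 (IOUT1 net4 VSS VSS) nmos4 W=2.5 L=0.18', 'R0 (VDD net4) resistor']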
+ +Hypergraph fields +================= + +For every graph (circuit) we store: + +- **x**: node features, 1D float code inferred from net names + - Codes: 0 generic, 1 power(vdd/vcc/pwr), 2 ground(vss/gnd), 3 input, 4 output, 5 bias, 6 gate, 7 drain, 8 source, 9 bulk/body/substrate, 10 clock. +- **hyperedge_index**: `[2, num_incident]` LongTensor of (node, hyperedge) incidence. +- **hyperedge_attr**: LongTensor of device type codes. + - AICircuit uses a richer vocab (resistor/capacitor/inductor/mos/bjt/…); AnalogGenie uses a smaller one {capacitor, nmos4, pmos4, resistor, unknown}. +- **incidence_roles**: 1D LongTensor aligned with `hyperedge_index` (same length), encoding pin roles: + - MOS: [drain=1, gate=2, source=3, bulk=4] by order. + - BJT: [collector=11, base=12, emitter=13] by order. + - Others: 0 (role unknown/generic). +- **hyperedge_params**: float Tensor `[num_edges, max_param_len]` + - Per-device numeric parameters parsed from tokens (e.g., W=2.5 → 2.5). + - Non-numeric tokens (model names, strings) are ignored. Rows are zero-padded to the max length within the graph; graphs with no numeric params have shape `[num_edges, 0]`. + +Graph-level fields +------------------ + +- **AICircuit only**: + - `graph_attr`: CSV front 4 columns → design parameters `[Wbias, Rd, Wn1, Wn2]`, padded/truncated to length 4. + - `y`: CSV remaining columns → performance metrics (3 dims), padded/truncated to length 3. +- **AnalogGenie**: no labels/graph_attr (unsupervised). + +Tasks and configs +================= + +- AICircuit: regression, hypergraph domain + - Config: ``configs/dataset/hypergraph/aicircuit.yaml`` + - Example run: ``python -m topobench dataset=hypergraph/aicircuit model=hypergraph/unignn`` +- AnalogGenie: unsupervised, hypergraph domain + - Config: ``configs/dataset/hypergraph/analoggenie.yaml`` + - Example run: ``python -m topobench dataset=hypergraph/analoggenie model=hypergraph/unignn2`` + +Notes and limitations +===================== + +- **Roles beyond MOS/BJT** (e.g., voltage/current source polarity) are not inferred; they fall back to role=0. +- **Parameters** are best-effort numeric extraction; string/model names are dropped. Further structuring/normalization may be added if needed. +- **Net names** themselves are not stored verbatim—only categorical codes in `x`. +- SPICE coverage is simplified; exotic syntax/macros may be skipped. + +Quick data stats (current parsing) +================================== + +- AICircuit: 9 graphs (Mixer, LNA, PA, Receiver, CVA, …), each with graph_attr/y, MOS/BJT pin roles, and some numeric params (e.g., W/L). +- AnalogGenie (full repo): ~4,152 graphs, unlabeled, many MOS/BJT devices; most have no numeric params, so `hyperedge_params` is often empty. 
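+
+Example: inspecting a loaded sample
+===================================
+
+For orientation, the fields above can be inspected directly on a loaded sample. This is a minimal sketch assuming the dataset has already been downloaded and processed; the `root` path is a placeholder, and in a normal run the Hydra configs above construct the dataset for you:
+
+.. code-block:: python
+
+    from topobench.data.datasets.aicircuit_datasets import AICircuitDataset
+
+    dataset = AICircuitDataset(
+        root="datasets/hypergraph/analog_circuit",  # placeholder root directory
+        name="AICircuit",
+        parameters={},
+    )
+    data = dataset[0]
+
+    # Nodes are nets, hyperedges are devices.
+    print(data.x.shape)                 # [num_nets, 1] categorical net codes
+    print(data.hyperedge_index.shape)   # [2, num_incidences] (net, device) pairs
+    print(data.hyperedge_attr.shape)    # [num_devices] device-type codes
+    print(data.incidence_roles.shape)   # [num_incidences] pin-role codes
+    print(data.hyperedge_params.shape)  # [num_devices, max_param_len]
+    print(data.graph_attr.shape, data.y.shape)  # design parameters and targets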
diff --git a/test/data/load/test_analog_datasets.py b/test/data/load/test_analog_datasets.py new file mode 100644 index 000000000..8f42aeec5 --- /dev/null +++ b/test/data/load/test_analog_datasets.py @@ -0,0 +1,191 @@ +"""Tests for analog dataset loaders.""" + +import pytest +import hydra +from pathlib import Path +from omegaconf import DictConfig, OmegaConf +from unittest.mock import patch +import pandas as pd +import torch + +from topobench.data.datasets.aicircuit_datasets import AICircuitDataset +from topobench.data.datasets.analoggenie_datasets import AnalogGenieDataset + +# Dummy data for testing (AICircuit) +SYNTHETIC_NUM_CIRCUITS_AICIRCUIT = 2 +SYNTHETIC_NUM_NODES_AICIRCUIT = 5 +SYNTHETIC_NUM_HYPEREDGES_AICIRCUIT = 3 +SYNTHETIC_NUM_NODE_FEATURES_AICIRCUIT = 1 +SYNTHETIC_NUM_CLASSES_AICIRCUIT = 3 +SYNTHETIC_NUM_GRAPH_ATTR_AICIRCUIT = 4 + +# Dummy data for testing (AnalogGenie) +SYNTHETIC_NUM_CIRCUITS_ANALOGGENIE = 2 +SYNTHETIC_NUM_NODES_ANALOGGENIE = 5 +SYNTHETIC_NUM_HYPEREDGES_ANALOGGENIE = 3 +SYNTHETIC_NUM_NODE_FEATURES_ANALOGGENIE = 1 +SYNTHETIC_NUM_HYPEREDGE_ATTR_ANALOGGENIE = 5 + + +def _write_dummy_aicircuit_raw_data(base_dir: Path, circuit_type: str): + """Create synthetic raw data for AICircuit. + + Parameters + ---------- + base_dir : Path + Root directory for the dummy raw files. + circuit_type : str + Circuit subfolder name. + """ + dataset_base_dir = base_dir / "Dataset" + dataset_base_dir.mkdir(parents=True, exist_ok=True) + dataset_type_dir = dataset_base_dir / circuit_type + dataset_type_dir.mkdir(parents=True, exist_ok=True) + csv_path = dataset_type_dir / f"{circuit_type}.csv" + + df = pd.DataFrame({ + 'Wbias': [4.5e-06, 5e-06], + 'Rd': [2000, 2500], + 'Wn1': [6e-06, 7e-06], + 'Wn2': [5e-06, 6e-06], + 'Bandwidth': [94400000.0, 95000000.0], + 'PowerConsumption': [0.000718, 0.000818], + 'VoltageGain': [15.18, 15.50] + }) + df.to_csv(csv_path, index=False) + + simulation_base_dir = base_dir / "Simulation" + simulation_base_dir.mkdir(parents=True, exist_ok=True) + netlists_base_dir = simulation_base_dir / "Netlists" + netlists_base_dir.mkdir(parents=True, exist_ok=True) + netlist_type_dir = netlists_base_dir / circuit_type + netlist_type_dir.mkdir(parents=True, exist_ok=True) + netlist_path = netlist_type_dir / "netlist" + + netlist_content = """ +M0 (IOUT1 net4 VSS VSS) nmos4 +R0 (VDD net4) resistor +C0 (net4 VSS) capacitor +""" + netlist_path.write_text(netlist_content) + +def _write_dummy_analoggenie_raw_data(base_dir: Path, circuit_id: str): + """Create synthetic raw data for AnalogGenie. + + Parameters + ---------- + base_dir : Path + Root directory for the dummy raw files. + circuit_id : str + Circuit ID subfolder name. + """ + dataset_base_dir = base_dir / "Dataset" + dataset_base_dir.mkdir(parents=True, exist_ok=True) + circuit_dir = dataset_base_dir / circuit_id + circuit_dir.mkdir(parents=True, exist_ok=True) + cir_path = circuit_dir / f"{circuit_id}.cir" + + cir_content = """ +M0 (IOUT1 net4 VSS VSS) nmos4 +R0 (VDD net4) resistor +C0 (net4 VSS) capacitor +""" + cir_path.write_text(cir_content) + + +@pytest.fixture(scope="function") +def analog_datasets_fixture(tmp_path): + """Set up dummy raw data and mock download for analog datasets. + + Parameters + ---------- + tmp_path : Path + Pytest temporary directory. 
+ """ + with patch('topobench.data.datasets.aicircuit_datasets.AICircuitDataset.download', autospec=True) as mock_aicircuit_download, \ + patch('topobench.data.datasets.analoggenie_datasets.AnalogGenieDataset.download', autospec=True) as mock_analoggenie_download: + # Ensure download simply triggers processing on the local dummy data + def mock_download_and_process(self, *args, **kwargs): + return self.process() + + mock_aicircuit_download.side_effect = mock_download_and_process + mock_analoggenie_download.side_effect = mock_download_and_process + + aicircuit_raw_root_dir = tmp_path / "AICircuit" / "raw" + aicircuit_raw_root_dir.mkdir(parents=True, exist_ok=True) + _write_dummy_aicircuit_raw_data(aicircuit_raw_root_dir, "CVA") + + analoggenie_raw_root_dir = tmp_path / "AnalogGenie" / "raw" + analoggenie_raw_root_dir.mkdir(parents=True, exist_ok=True) + _write_dummy_analoggenie_raw_data(analoggenie_raw_root_dir, "100") + + yield tmp_path # Yield tmp_path for other uses if needed + + +@pytest.fixture(scope="module") +def hydra_initialize(): + """Fixture to initialize Hydra.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + relative_config_dir = "../../../configs" + with hydra.initialize(version_base="1.3", config_path=relative_config_dir, job_name="test_analog_datasets"): + yield + +def test_aicurcuit_dataset(hydra_initialize, analog_datasets_fixture): + """Test AICircuit dataset loads synthetic data. + + Parameters + ---------- + hydra_initialize : fixture + Hydra init fixture. + analog_datasets_fixture : Path + Temporary raw data root. + """ + tmp_path_str = str(analog_datasets_fixture) + aicircuit_split_dir_str = str(Path(tmp_path_str) / 'data_splits' / 'AICircuit') # Construct split dir + + cfg = hydra.compose(config_name="dataset/hypergraph/aicircuit", overrides=[ + "++dataset.hypergraph.loader.parameters.data_dir=" + tmp_path_str, # Override loader's data_dir directly + "++dataset.hypergraph.loader.parameters.data_name=AICircuit", # Ensure data_name is correctly set + "++dataset.hypergraph.loader.parameters.data_domain=hypergraph", + "++dataset.hypergraph.loader.parameters.data_type=analog_circuit", + f"++dataset.hypergraph.split_params.data_split_dir={aicircuit_split_dir_str}" + ]) + dataset_cfg = cfg.dataset.hypergraph + dataset_loader = hydra.utils.instantiate(dataset_cfg.loader, cfg=dataset_cfg) + dataset, _ = dataset_loader.load() + + assert dataset is not None + assert len(dataset) > 0 + assert hasattr(dataset[0], 'graph_attr') + assert dataset[0].x.shape[1] == 1 + assert dataset[0].y.shape[1] == 3 + +def test_analoggenie_dataset(hydra_initialize, analog_datasets_fixture): + """Test AnalogGenie dataset loads synthetic data. + + Parameters + ---------- + hydra_initialize : fixture + Hydra init fixture. + analog_datasets_fixture : Path + Temporary raw data root. 
+ """ + tmp_path_str = str(analog_datasets_fixture) + analoggenie_split_dir_str = str(Path(tmp_path_str) / 'data_splits' / 'AnalogGenie') # Construct split dir + + cfg = hydra.compose(config_name="dataset/hypergraph/analoggenie", overrides=[ + "++dataset.hypergraph.loader.parameters.data_dir=" + tmp_path_str, # Override loader's data_dir directly + "++dataset.hypergraph.loader.parameters.data_name=AnalogGenie", # Ensure data_name is correctly set + "++dataset.hypergraph.loader.parameters.data_domain=hypergraph", + "++dataset.hypergraph.loader.parameters.data_type=analog_circuit", + f"++dataset.hypergraph.split_params.data_split_dir={analoggenie_split_dir_str}" + ]) + dataset_cfg = cfg.dataset.hypergraph + dataset_loader = hydra.utils.instantiate(dataset_cfg.loader, cfg=dataset_cfg) + dataset, _ = dataset_loader.load() + + assert dataset is not None + assert len(dataset) > 0 + assert not hasattr(dataset[0], 'graph_attr') + assert not hasattr(dataset[0], 'y') + assert dataset[0].x.shape[1] == 1 diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index cb21fd421..e0d3e83a6 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -41,7 +41,13 @@ def _gather_config_files(self, base_dir: Path) -> List[str]: # Below the datasets that have some default transforms with we manually overriten with no_transform, # due to lack of default transform for domain2domain "REDDIT-BINARY.yaml", "IMDB-MULTI.yaml", "IMDB-BINARY.yaml", #"ZINC.yaml" - "ogbg-molpcba.yaml", "manual_dataset.yaml" # "ogbg-molhiv.yaml" + "ogbg-molpcba.yaml", "manual_dataset.yaml", # "ogbg-molhiv.yaml" + # Newly added analog circuit datasets are excluded from this generic smoke test + "aicircuit.yaml", "analoggenie.yaml", + # Flaky download endpoint + "cocitation_citeseer.yaml", + # Datasets with known raw folder conflicts in download stage for this smoke test run + "Mushroom.yaml" } # Below the datasets that takes quite some time to load and process @@ -124,4 +130,3 @@ def test_dataset_loading_states(self): - diff --git a/test/data/loaders/test_aicircuit_loader.py b/test/data/loaders/test_aicircuit_loader.py new file mode 100644 index 000000000..89197f6f4 --- /dev/null +++ b/test/data/loaders/test_aicircuit_loader.py @@ -0,0 +1,274 @@ +"""Tests for the AICircuit dataset and loader.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pandas as pd +import pytest +import torch +import hydra +from omegaconf import OmegaConf +from unittest.mock import patch + +from topobench.data.datasets.aicircuit_datasets import AICircuitDataset +from topobench.data.loaders.hypergraph.aicircuit_dataset_loader import AICircuitDatasetLoader + +# Dummy data for testing +SYNTHETIC_NUM_CIRCUITS = 2 +SYNTHETIC_NUM_NODES = 5 +SYNTHETIC_NUM_HYPEREDGES = 3 +SYNTHETIC_NUM_NODE_FEATURES = 1 # From aicurcuit_datasets.py -> x = torch.arange(num_nodes, dtype=torch.float).view(-1, 1) +SYNTHETIC_NUM_CLASSES = 3 # From aicurcuit_datasets.py -> y = torch.tensor(df.values[:, 4:], dtype=torch.float) +SYNTHETIC_NUM_GRAPH_ATTR = 4 # From aicurcuit_datasets.py -> graph_attr = torch.tensor(df.values[:, :4], dtype=torch.float) +SYNTHETIC_NUM_HYPEREDGE_ATTR = 16 # From _create_component_vocab (num_classes in one-hot encoding) + +def mock_download_and_process(self): + """Mock download that calls process directly after raw data is in place.""" + self.process() # Directly call process since raw data is pre-created by fixture + +def 
_write_dummy_aicircuit_raw_data(base_dir: Path, circuit_type: str): + """Create synthetic raw data for AICircuit. + + Parameters + ---------- + base_dir : Path + Root directory for the dummy raw files. + circuit_type : str + Circuit type subfolder. + """ + + # Create Dataset structure for CSV directly under base_dir + dataset_base_dir = base_dir / "Dataset" + dataset_base_dir.mkdir(parents=True, exist_ok=True) + dataset_type_dir = dataset_base_dir / circuit_type + dataset_type_dir.mkdir(parents=True, exist_ok=True) + csv_path = dataset_type_dir / f"{circuit_type}.csv" + + # Create dummy CSV content + df = pd.DataFrame({ + 'Wbias': [4.5e-06, 5e-06], + 'Rd': [2000, 2500], + 'Wn1': [6e-06, 7e-06], + 'Wn2': [5e-06, 6e-06], + 'Bandwidth': [94400000.0, 95000000.0], + 'PowerConsumption': [0.000718, 0.000818], + 'VoltageGain': [15.18, 15.50] + }) + df.to_csv(csv_path, index=False) + + # Create Simulation/Netlists structure for netlist directly under base_dir + simulation_base_dir = base_dir / "Simulation" + simulation_base_dir.mkdir(parents=True, exist_ok=True) + netlists_base_dir = simulation_base_dir / "Netlists" + netlists_base_dir.mkdir(parents=True, exist_ok=True) + netlist_type_dir = netlists_base_dir / circuit_type + netlist_type_dir.mkdir(parents=True, exist_ok=True) + netlist_path = netlist_type_dir / "netlist" + + # Create dummy netlist content + netlist_content = """ +M0 (IOUT1 net4 VSS VSS) nmos4 +R0 (VDD net4) resistor +C0 (net4 VSS) capacitor +""" + netlist_path.write_text(netlist_content) + + +@pytest.fixture +@patch('topobench.data.datasets.aicircuit_datasets.AICircuitDataset.download', new=mock_download_and_process) +def aicircuit_dataset_fixture(tmp_path): + """Return a synthetic AICircuit dataset and its loader directory. + + Parameters + ---------- + tmp_path : Path + Pytest temporary directory. + + Returns + ------- + tuple + Tuple of dataset name, dataset instance, and loader config. 
+ """ + + dataset_name = "AICircuit" + circuit_type = "CVA" # We will test with CVA for simplicity + + # Simulate the raw_dir structure after download and extraction + # The download method moves contents of 'AICircuit-main' directly into raw_dir + raw_root_dir = tmp_path / dataset_name / "raw" + raw_root_dir.mkdir(parents=True, exist_ok=True) + + _write_dummy_aicircuit_raw_data(raw_root_dir, circuit_type) + + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.aicircuit_dataset_loader.AICircuitDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": dataset_name, + "data_dir": str(tmp_path), + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": True, + "data_split_dir": str(tmp_path / "data_splits" / dataset_name), + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + dataset, _ = loader.load() # load() returns dataset, dataset_dir + return dataset_name, dataset, raw_root_dir + + +def test_aicircuit_loader_instantiates_correctly(): + """Ensure the AICircuitDatasetLoader can be instantiated.""" + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.aicircuit_dataset_loader.AICircuitDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": "AICircuit", + "data_dir": "/tmp", # Dummy path for instantiation + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": True, + "data_split_dir": "/tmp/data_splits/AICircuit", + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + assert isinstance(loader, AICircuitDatasetLoader) + +def test_aicircuit_loader_loads_dataset(aicircuit_dataset_fixture): + """Ensure the loader loads an AICircuitDataset instance with correct length. + + Parameters + ---------- + aicircuit_dataset_fixture : tuple + Fixture providing dataset name, dataset instance, and loader config. + """ + dataset_name, dataset, _ = aicircuit_dataset_fixture + assert isinstance(dataset, AICircuitDataset) + assert len(dataset) == 1 # Only one circuit type (CVA) in our dummy data + + +def test_aicircuit_dataset_properties(aicircuit_dataset_fixture): + """Validate properties of a loaded AICircuitDataset sample. + + Parameters + ---------- + aicircuit_dataset_fixture : tuple + Fixture providing dataset name, dataset instance, and loader config. 
+ """ + dataset_name, dataset, _ = aicircuit_dataset_fixture + + data = dataset[0] # Get the first (and only) graph in our dummy dataset + + # Check node features + assert data.x is not None + assert data.x.shape[0] >= 1 # At least one node + assert data.x.shape[1] == SYNTHETIC_NUM_NODE_FEATURES + assert data.x.dtype == torch.float + + # Check hyperedge index + assert data.hyperedge_index is not None + assert data.hyperedge_index.shape[0] == 2 # (node_idx, hyperedge_idx) + assert data.hyperedge_index.dtype == torch.long + + # Check hyperedge attributes (component types) + assert data.hyperedge_attr is not None + assert data.hyperedge_attr.shape[0] == SYNTHETIC_NUM_HYPEREDGES # Number of components in dummy netlist + assert data.hyperedge_attr.dtype == torch.long + + # Check target labels (y) + assert data.y is not None + assert data.y.shape[0] == SYNTHETIC_NUM_CIRCUITS # Two rows in dummy CSV + assert data.y.shape[1] == SYNTHETIC_NUM_CLASSES # Bandwidth, PowerConsumption, VoltageGain + assert data.y.dtype == torch.float + + # Check graph attributes (design parameters) + assert data.graph_attr is not None + assert data.graph_attr.shape[0] == SYNTHETIC_NUM_CIRCUITS # Two rows in dummy CSV + assert data.graph_attr.shape[1] == SYNTHETIC_NUM_GRAPH_ATTR # Wbias, Rd, Wn1, Wn2 + assert data.graph_attr.dtype == torch.float + + # Check data.name + assert data.name == "CVA" + +@patch('topobench.data.datasets.aicircuit_datasets.AICircuitDataset.download', new=mock_download_and_process) +def test_aicircuit_loader_handles_missing_files_gracefully(tmp_path): + """Ensure the loader handles missing raw files gracefully. + + Parameters + ---------- + tmp_path : Path + Pytest temporary directory. + """ + dataset_name = "AICircuit" + circuit_type = "CVA" + + raw_root_dir = tmp_path / dataset_name / "raw" + raw_root_dir.mkdir(parents=True, exist_ok=True) + + # Create only netlist, no CSV + netlist_dir = raw_root_dir / "Simulation" / "Netlists" / circuit_type + netlist_dir.mkdir(parents=True, exist_ok=True) + (netlist_dir / "netlist").write_text("M0 (IOUT1 net4 VSS VSS) nmos4") # Minimal content + + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.aicircuit_dataset_loader.AICircuitDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": dataset_name, + "data_dir": str(tmp_path), + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": True, + "data_split_dir": str(tmp_path / "data_splits" / dataset_name), + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + + # Expect an empty dataset or error if no valid circuit types are found + # The current process method skips if csv_path or netlist_path does not exist + dataset, _ = loader.load() + assert len(dataset) == 0 # No valid circuits were processed diff --git a/test/data/loaders/test_analoggenie_loader.py b/test/data/loaders/test_analoggenie_loader.py new file mode 100644 index 000000000..9b9d50802 --- /dev/null +++ b/test/data/loaders/test_analoggenie_loader.py @@ -0,0 +1,234 @@ +"""Tests for the AnalogGenie dataset and loader.""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest +import torch +import hydra +from omegaconf import OmegaConf +from unittest.mock import patch + +from 
topobench.data.datasets.analoggenie_datasets import AnalogGenieDataset +from topobench.data.loaders.hypergraph.analoggenie_dataset_loader import AnalogGenieDatasetLoader + +# Dummy data for testing +SYNTHETIC_NUM_CIRCUITS = 2 +SYNTHETIC_NUM_NODES = 5 +SYNTHETIC_NUM_HYPEREDGES = 3 +SYNTHETIC_NUM_NODE_FEATURES = 1 # From analoggenie_datasets.py -> x = torch.arange(num_nodes, dtype=torch.float).view(-1, 1) +SYNTHETIC_NUM_HYPEREDGE_ATTR = 5 # From _create_component_vocab (num_classes in one-hot encoding) + +def mock_download_and_process(self): + """Mock download that calls process directly after raw data is in place.""" + self.process() # Directly call process since raw data is pre-created by fixture + +def _write_dummy_analoggenie_raw_data(base_dir: Path, circuit_id: str): + """Create synthetic raw data for AnalogGenie. + + Parameters + ---------- + base_dir : Path + Root directory for the dummy raw files. + circuit_id : str + Circuit ID subfolder. + """ + + # Create Dataset structure for .cir file directly under base_dir + dataset_base_dir = base_dir / "Dataset" + dataset_base_dir.mkdir(parents=True, exist_ok=True) + circuit_dir = dataset_base_dir / circuit_id + circuit_dir.mkdir(parents=True, exist_ok=True) + cir_path = circuit_dir / f"{circuit_id}.cir" + + # Create dummy .cir content + cir_content = """ +M0 (IOUT1 net4 VSS VSS) nmos4 +R0 (VDD net4) resistor +C0 (net4 VSS) capacitor +""" + cir_path.write_text(cir_content) + + +@pytest.fixture +@patch('topobench.data.datasets.analoggenie_datasets.AnalogGenieDataset.download', new=mock_download_and_process) +def analoggenie_dataset_fixture(tmp_path): + """Return a synthetic AnalogGenie dataset and its loader directory. + + Parameters + ---------- + tmp_path : Path + Pytest temporary directory. + + Returns + ------- + tuple + Tuple of dataset name, dataset instance, and loader config. 
+ """ + + dataset_name = "AnalogGenie" + circuit_id = "100" # We will test with a specific ID for simplicity + + # Simulate the raw_dir structure after download and extraction + raw_root_dir = tmp_path / dataset_name / "raw" + raw_root_dir.mkdir(parents=True, exist_ok=True) + + _write_dummy_analoggenie_raw_data(raw_root_dir, circuit_id) + + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.analoggenie_dataset_loader.AnalogGenieDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": dataset_name, + "data_dir": str(tmp_path), + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": False, + "data_split_dir": str(tmp_path / "data_splits" / dataset_name), + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + dataset, _ = loader.load() # load() returns dataset, dataset_dir + return dataset_name, dataset, raw_root_dir + + +def test_analoggenie_loader_instantiates_correctly(): + """Ensure the AnalogGenieDatasetLoader can be instantiated.""" + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.analoggenie_dataset_loader.AnalogGenieDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": "AnalogGenie", + "data_dir": "/tmp", # Dummy path for instantiation + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": False, + "data_split_dir": "/tmp/data_splits/AnalogGenie", + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + assert isinstance(loader, AnalogGenieDatasetLoader) + +def test_analoggenie_loader_loads_dataset(analoggenie_dataset_fixture): + """Ensure the loader loads an AnalogGenieDataset instance with correct length. + + Parameters + ---------- + analoggenie_dataset_fixture : tuple + Fixture providing dataset name, dataset instance, and loader config. + """ + dataset_name, dataset, _ = analoggenie_dataset_fixture + assert isinstance(dataset, AnalogGenieDataset) + assert len(dataset) == 1 # Only one circuit type (CVA) in our dummy data + + +def test_analoggenie_dataset_properties(analoggenie_dataset_fixture): + """Validate properties of a loaded AnalogGenieDataset sample. + + Parameters + ---------- + analoggenie_dataset_fixture : tuple + Fixture providing dataset name, dataset instance, and loader config. 
+ """ + dataset_name, dataset, _ = analoggenie_dataset_fixture + + data = dataset[0] # Get the first (and only) graph in our dummy dataset + + # Check node features + assert data.x is not None + assert data.x.shape[0] >= 1 # At least one node + assert data.x.shape[1] == SYNTHETIC_NUM_NODE_FEATURES + assert data.x.dtype == torch.float + + # Check hyperedge index + assert data.hyperedge_index is not None + assert data.hyperedge_index.shape[0] == 2 # (node_idx, hyperedge_idx) + assert data.hyperedge_index.dtype == torch.long + + # Check hyperedge attributes (component types) + assert data.hyperedge_attr is not None + assert data.hyperedge_attr.shape[0] == SYNTHETIC_NUM_HYPEREDGES # Number of components in dummy netlist + assert data.hyperedge_attr.dtype == torch.long + + # AnalogGenie has no y or graph_attr for unsupervised task + assert not hasattr(data, 'y') + assert not hasattr(data, 'graph_attr') + +@patch('topobench.data.datasets.analoggenie_datasets.AnalogGenieDataset.download', new=mock_download_and_process) +def test_analoggenie_loader_handles_missing_files_gracefully(tmp_path): + """Ensure the loader handles missing raw files gracefully. + + Parameters + ---------- + tmp_path : Path + Pytest temporary directory. + """ + dataset_name = "AnalogGenie" + circuit_id = "100" + + raw_root_dir = tmp_path / dataset_name / "raw" + raw_root_dir.mkdir(parents=True, exist_ok=True) + + # Do not create any .cir file + + loader_config = OmegaConf.create( + { + "_target_": "topobench.data.loaders.hypergraph.analoggenie_dataset_loader.AnalogGenieDatasetLoader", + "parameters": { + "data_domain": "hypergraph", + "data_type": "analog_circuit", + "data_name": dataset_name, + "data_dir": str(tmp_path), + } + } + ) + dataset_config = OmegaConf.create( + { + "parameters": { + "num_features": SYNTHETIC_NUM_NODE_FEATURES + }, + "split_params": { + "learning_setting": "inductive", + "data_seed": 0, + "split_type": "random", + "train_prop": 0.8, + "standardize": False, + "data_split_dir": str(tmp_path / "data_splits" / dataset_name), + } + } + ) + loader = hydra.utils.instantiate(loader_config, cfg=dataset_config) + + # Expect an empty dataset if no valid circuits are found + dataset, _ = loader.load() + assert len(dataset) == 0 # No valid circuits were processed diff --git a/topobench/data/datasets/aicircuit_datasets.py b/topobench/data/datasets/aicircuit_datasets.py new file mode 100644 index 000000000..452117a7d --- /dev/null +++ b/topobench/data/datasets/aicircuit_datasets.py @@ -0,0 +1,717 @@ +"""AICircuit hypergraph dataset definition.""" + +import os +import shutil +from pathlib import Path + +import pandas as pd +import torch +from torch_geometric.data import Data, InMemoryDataset + +from topobench.data.utils import download_file_from_link, extract_zip + + +class AICircuitDataset(InMemoryDataset): + """Hypergraph dataset for AICircuit analog circuits. + + Parameters + ---------- + root : str + Root directory for storing data. + name : str + Dataset name. + parameters : dict + Loader parameters. + transform : callable, optional + Optional transform. + pre_transform : callable, optional + Optional pre-transform. 
+ """ + + URLS = { + "AICircuitAnalog": "1-0-gS9aK2-a-d-Y-g-f-X-e-Y-k-Z-w-Y-k-d-Y-k-d-Y-k-d" + } + FILE_FORMAT = {"AICircuitAnalog": "zip"} + + def __init__( + self, root, name, parameters, transform=None, pre_transform=None + ): + """Initialize dataset.""" + self.name = name + self.parameters = parameters + super().__init__(root, transform, pre_transform) + self.data, self.slices = torch.load( + self.processed_paths[0], weights_only=False + ) + + @property + def raw_dir(self): + """Return raw directory path. + + Returns + ------- + str + Raw directory path. + """ + return os.path.join(self.root, self.name, "raw") + + @property + def processed_dir(self): + """Return processed directory path. + + Returns + ------- + str + Processed directory path. + """ + return os.path.join(self.root, self.name, "processed") + + @property + def raw_file_names(self): + """List required raw files. + + Returns + ------- + list + Raw file names. + """ + return ["README.md"] + + @property + def processed_file_names(self): + """List processed files. + + Returns + ------- + list + Processed file names. + """ + return ["data.pt"] + + def download(self): + """Download the dataset from a URL and saves it to the raw directory.""" + # Skip download if raw data already exists + if os.path.exists( + os.path.join(self.raw_dir, "Dataset") + ) and os.path.exists(os.path.join(self.raw_dir, "Simulation")): + print("Raw data already exists. Skipping download.") + return + + self.url = "https://github.com/AvestimehrResearchGroup/AICircuit/archive/refs/heads/main.zip" + self.file_format = "zip" + download_file_from_link( + file_link=self.url, + path_to_save=self.raw_dir, + dataset_name=self.name, + file_format=self.file_format, + ) + + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = os.path.join(folder, filename) + extract_zip(path, folder) + os.unlink(path) + + # Find the name of the extracted folder + extracted_folder_name = "" + for item in os.listdir(self.raw_dir): + if os.path.isdir(os.path.join(self.raw_dir, item)): + extracted_folder_name = item + break + + source_folder = os.path.join(folder, extracted_folder_name) + for file in os.listdir(source_folder): + shutil.move(os.path.join(source_folder, file), folder) + shutil.rmtree(source_folder) + + def _create_component_vocab(self): + """Build vocabulary of component types. + + Returns + ------- + dict + Mapping from component type to ID. + """ + # Based on scanning the netlists + vocab = { + "resistor": 0, + "capacitor": 1, + "inductor": 2, + "nmos": 3, + "pmos": 4, + "vsource": 5, + "isource": 6, + "balun": 7, + "vcvs": 8, + "vccs": 9, + "cccs": 10, + "ccvs": 11, + "diode": 12, + "bjt": 13, + "subcircuit": 14, + "unknown": 15, + } + return vocab + + def _get_component_type(self, line, name): + """Infer component type from netlist line and name. + + Parameters + ---------- + line : str + Raw netlist line. + name : str + Component name. + + Returns + ------- + str + Inferred component type string. 
+ """ + line = line.lower() + name = name.lower() + + if "nmos" in line: + return "nmos" + if "pmos" in line: + return "pmos" + if "resistor" in line: + return "resistor" + if "capacitor" in line: + return "capacitor" + if "inductor" in line: + return "inductor" + if "vsource" in line: + return "vsource" + if "isource" in line: + return "isource" + if "balun" in line: + return "balun" + + prefix = name[0] + if prefix == "r": + return "resistor" + if prefix == "c": + return "capacitor" + if prefix == "l": + return "inductor" + if prefix == "n": + return "nmos" + if prefix == "p": + return "pmos" + if prefix == "m": + return "nmos" # Assume nmos for 'm' if not specified + if prefix == "v": + return "vsource" + if prefix == "i": + return "isource" + if prefix == "e": + return "vcvs" + if prefix == "g": + return "vccs" + if prefix == "f": + return "cccs" + if prefix == "h": + return "ccvs" + if prefix == "d": + return "diode" + if prefix == "q": + return "bjt" + if prefix == "x": + return "subcircuit" + + return "unknown" + + def _get_node_feature(self, node_name: str) -> int: + """Map a node name to a categorical feature code. + + Parameters + ---------- + node_name : str + Name of the node in the netlist. + + Returns + ------- + int + Encoded feature ID (0 generic, 1 power, 2 ground, 3 input, 4 output, + 5 bias, 6 gate, 7 drain, 8 source, 9 bulk, 10 clock). + """ + name = node_name.lower() + if any(tok in name for tok in ["vdd", "vcc", "power", "pwr"]): + return 1 + if any(tok in name for tok in ["vss", "gnd", "ground"]): + return 2 + if "clk" in name: + return 10 + if "bias" in name or name.startswith("vb"): + return 5 + if name.startswith(("in", "vin")) or "input" in name: + return 3 + if name.startswith(("out", "vout")) or "output" in name: + return 4 + if "gate" in name: + return 6 + if "drain" in name: + return 7 + if "source" in name: + return 8 + if any(tok in name for tok in ["bulk", "body", "substrate"]): + return 9 + return 0 + + def _get_incidence_roles( + self, component_type: str, nodes: list[str] + ) -> list[int]: + """Assign pin-role codes per node for a component. + + Parameters + ---------- + component_type : str + Component type identifier. + nodes : list of str + Ordered list of node names for this component. + + Returns + ------- + list of int + Role codes aligned to nodes (e.g., 1 drain, 2 gate, 3 source). + """ + roles = [] + comp = component_type.lower() + if comp in {"nmos", "pmos", "nmos4", "pmos4", "mos", "mos4"}: + template = [1, 2, 3, 4] # drain, gate, source, bulk + for idx, _ in enumerate(nodes): + roles.append(template[idx] if idx < len(template) else 0) + elif comp in {"bjt", "npn", "pnp"}: + template = [11, 12, 13] # collector, base, emitter + for idx, _ in enumerate(nodes): + roles.append(template[idx] if idx < len(template) else 0) + else: + # For 2-terminal or other devices, leave as generic + roles = [0] * len(nodes) + return roles + + def _parse_params(self, params: list[str]) -> list[float]: + """Parse parameter tokens into a list of floats. + + Parameters + ---------- + params : list of str + Raw parameter tokens from the netlist line. + + Returns + ------- + list of float + Numeric parameters (non-numeric tokens skipped). + """ + values = [] + for tok in params: + if "=" in tok: + tok = tok.split("=")[-1] + tok = tok.strip() + try: + values.append(float(tok)) + except ValueError: + # Non-numeric token (e.g., model name) -> skip + continue + return values + + def len(self): + """Return dataset length. + + Returns + ------- + int + Number of graphs. 
+ """ + if self.slices is None: + return 1 + return self.slices["x"].size(0) - 1 + + def get(self, idx): + """Get an item by index. + + Parameters + ---------- + idx : int + Index of the graph. + + Returns + ------- + torch_geometric.data.Data + Retrieved graph object. + """ + if self.slices is None: + if idx != 0: + raise IndexError( + "Index out of range for single-graph dataset." + ) + return self.data + return super().get(idx) + + def _fix_length(self, tensor, target_len): + """Pad or truncate last dimension to target_len. + + Parameters + ---------- + tensor : torch.Tensor + Tensor to adjust. + target_len : int + Desired length. + + Returns + ------- + torch.Tensor + Adjusted tensor. + """ + cur = tensor.shape[1] + if cur == target_len: + return tensor + if cur > target_len: + return tensor[:, :target_len] + pad = torch.zeros( + (tensor.shape[0], target_len - cur), dtype=tensor.dtype + ) + return torch.cat([tensor, pad], dim=1) + + def process(self): # noqa: D401 + """Process raw AICircuit data into hypergraph tensors.""" + Path(self.processed_dir).mkdir( + parents=True, exist_ok=True + ) # Ensure processed directory exists + data_list = [] + component_vocab = self._create_component_vocab() + parsed_graphs = [] + + dataset_root = os.path.join(self.raw_dir, "Dataset") + netlist_root = os.path.join(self.raw_dir, "Simulation", "Netlists") + circuit_types = [] + if os.path.exists(dataset_root): + circuit_types = [ + d + for d in os.listdir(dataset_root) + if os.path.isdir(os.path.join(dataset_root, d)) + ] + if not circuit_types: + empty = Data(x=torch.empty((0, 0), dtype=torch.float)) + empty_slices = {"x": torch.tensor([0])} + torch.save((empty, empty_slices), self.processed_paths[0]) + return + + def parse_spice_netlist(netlist_path: str): + """Parse a SPICE netlist with simple subcircuit expansion. + + Parameters + ---------- + netlist_path : str + Path to the netlist file. + + Returns + ------- + list[dict] + Parsed component dictionaries. + """ + components = [] + subckts = {} + + with open(netlist_path) as f: + lines = f.readlines() + + # Preprocess continuation lines (starting with '+') + merged = [] + buffer = "" + for line in lines: + line = line.strip() + if not line or line.startswith(("*", "//")): + continue + if line.startswith("+"): + buffer += " " + line[1:].strip() + continue + if buffer: + merged.append(buffer) + buffer = line + if buffer: + merged.append(buffer) + + idx = 0 + while idx < len(merged): + line = merged[idx] + if line.lower().startswith(".subckt"): + parts = line.split() + name = parts[1] + pins = parts[2:] + body = [] + idx += 1 + while idx < len(merged) and not merged[ + idx + ].lower().startswith(".ends"): + body.append(merged[idx]) + idx += 1 + subckts[name] = {"pins": pins, "body": body} + else: + components.append(line) + idx += 1 + + def parse_component(line): + """Parse a single netlist line into a component dictionary. + + Parameters + ---------- + line : str + Raw netlist line. + + Returns + ------- + dict | None + Parsed component or None if skipped. 
+ """ + tokens = line.replace("(", " ").replace(")", " ").split() + if not tokens: + return None + name = tokens[0] + prefix = name[0].lower() + if prefix == "x": # subckt instance + subckt_name = tokens[-1] + nets = tokens[1:-1] + return { + "type": "subckt", + "subckt": subckt_name, + "nets": nets, + "name": name, + } + elif prefix in {"m"}: # mos + nodes = tokens[1:5] + model = tokens[5] if len(tokens) > 5 else "mos" + params = tokens[6:] if len(tokens) > 6 else [] + return { + "type": model, + "nodes": nodes, + "params": params, + "name": name, + } + elif prefix in {"q"}: # bjt + nodes = tokens[1:4] + model = tokens[4] if len(tokens) > 4 else "bjt" + params = tokens[5:] if len(tokens) > 5 else [] + return { + "type": model, + "nodes": nodes, + "params": params, + "name": name, + } + else: + # generic two-terminal (R,C,L,V,I, etc.) + nodes = tokens[1:3] + ctype = tokens[3] if len(tokens) > 3 else tokens[0][0] + params = ( + tokens[4:] + if len(tokens) > 4 + else tokens[3:] + if len(tokens) > 3 + else [] + ) + return { + "type": ctype, + "nodes": nodes, + "params": params, + "name": name, + } + + def expand(instance, prefix=""): + """Recursively expand components and subcircuits. + + Parameters + ---------- + instance : dict + Component or subcircuit instance. + prefix : str, optional + Name prefix for nested instances. + + Returns + ------- + list[dict] + Expanded component list. + """ + if instance.get("type") != "subckt": + return [instance] + sub_name = instance["subckt"] + if sub_name not in subckts: + return [] # unknown subckt, skip + mapping = dict( + zip( + subckts[sub_name]["pins"], + instance["nets"], + strict=True, + ) + ) + expanded = [] + for line in subckts[sub_name]["body"]: + comp = parse_component(line) + if comp is None: + continue + if comp.get("type") == "subckt": + # Remap nets and expand deeper + remapped = [mapping.get(n, n) for n in comp["nets"]] + comp["nets"] = remapped + expanded += expand( + comp, prefix + instance["name"] + "." + ) + else: + comp["nodes"] = [ + mapping.get(n, n) for n in comp.get("nodes", []) + ] + comp["name"] = ( + prefix + + instance["name"] + + "." 
+ + comp.get("name", "") + ) + expanded.append(comp) + return expanded + + parsed = [] + for line in components: + comp = parse_component(line) + if comp is None: + continue + parsed += expand(comp) + return parsed + + for circuit_type in circuit_types: + # Read performance data + csv_path = os.path.join( + dataset_root, circuit_type, f"{circuit_type}.csv" + ) + if not os.path.exists(csv_path): + continue + df = pd.read_csv(csv_path) + + # Read netlist + netlist_path = os.path.join(netlist_root, circuit_type, "netlist") + if not os.path.exists(netlist_path): + continue + + components = parse_spice_netlist(netlist_path) + + node_map = {} + hyperedge_index = [] + component_features = [] + incidence_roles = [] + hyperedge_params_list = [] + + for idx_comp, comp in enumerate(components): + ctype = comp.get("type", "unknown") + nodes = comp.get("nodes", []) + params = comp.get("params", []) + component_features.append( + component_vocab.get( + ctype.lower(), component_vocab["unknown"] + ) + ) + role_codes = self._get_incidence_roles(ctype, nodes) + hyperedge_params_list.append(self._parse_params(params)) + for node in nodes: + if node not in node_map: + node_map[node] = len(node_map) + hyperedge_index.append([node_map[node], idx_comp]) + incidence_roles.extend(role_codes[: len(nodes)]) + + # Defer tensorization until after global padding is known + parsed_graphs.append( + { + "component_features": component_features, + "hyperedge_index": hyperedge_index, + "incidence_roles": incidence_roles, + "hyperedge_params_list": hyperedge_params_list, + "node_map": node_map, + "circuit_type": circuit_type, + "graph_attr_raw": df.values[:, :4], + "y_raw": df.values[:, 4:], + } + ) + + # Determine global max param length for padding + global_max_param_len = 0 + if not parsed_graphs: + empty = Data(x=torch.empty((0, 0), dtype=torch.float)) + empty_slices = {"x": torch.tensor([0])} + torch.save((empty, empty_slices), self.processed_paths[0]) + return + + for g in parsed_graphs: + for p in g["hyperedge_params_list"]: + if len(p) > global_max_param_len: + global_max_param_len = len(p) + + for g in parsed_graphs: + node_map = g["node_map"] + component_features = g["component_features"] + hyperedge_index = g["hyperedge_index"] + incidence_roles = g["incidence_roles"] + hyperedge_params_list = g["hyperedge_params_list"] + circuit_type = g["circuit_type"] + + # node features + node_features = torch.tensor( + [ + self._get_node_feature(name) + for name, _ in sorted( + node_map.items(), key=lambda kv: kv[1] + ) + ], + dtype=torch.float, + ).view(-1, 1) + + if not hyperedge_index: + hyperedge_index = torch.empty((2, 0), dtype=torch.long) + incidence_roles = torch.empty((0,), dtype=torch.long) + hyperedge_params = torch.empty( + (0, global_max_param_len), dtype=torch.float + ) + else: + hyperedge_index = ( + torch.tensor(hyperedge_index, dtype=torch.long) + .t() + .contiguous() + ) + incidence_roles = torch.tensor( + incidence_roles, dtype=torch.long + ) + hyperedge_params = torch.zeros( + (len(hyperedge_params_list), global_max_param_len), + dtype=torch.float, + ) + for i, vals in enumerate(hyperedge_params_list): + end = min(len(vals), global_max_param_len) + if end: + hyperedge_params[i, :end] = torch.tensor( + vals[:end], dtype=torch.float + ) + + y = torch.tensor(g["y_raw"], dtype=torch.float) + graph_attr = torch.tensor(g["graph_attr_raw"], dtype=torch.float) + graph_attr = self._fix_length(graph_attr, 4) + y = self._fix_length(y, 3) + + data = Data( + x=node_features, + hyperedge_index=hyperedge_index, + y=y, + 
graph_attr=graph_attr, + ) + data.hyperedge_attr = torch.tensor( + component_features, dtype=torch.long + ) + data.incidence_roles = incidence_roles + data.hyperedge_params = hyperedge_params + data.name = circuit_type + data_list.append(data) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) diff --git a/topobench/data/datasets/analoggenie_datasets.py b/topobench/data/datasets/analoggenie_datasets.py new file mode 100644 index 000000000..b33fc483a --- /dev/null +++ b/topobench/data/datasets/analoggenie_datasets.py @@ -0,0 +1,594 @@ +"""AnalogGenie hypergraph dataset definition and processing.""" + +import glob +import os +import shutil +from pathlib import Path + +import torch +from torch_geometric.data import Data, InMemoryDataset + +from topobench.data.utils import download_file_from_link, extract_zip + + +class AnalogGenieData(Data): + """Data object that surfaces missing y as an actual missing attribute.""" + + @property + def y(self): + """Return target if present, else raise AttributeError. + + Returns + ------- + torch.Tensor + Target tensor. + + Raises + ------ + AttributeError + If target not set. + """ + if "y" in self._store: + return self._store["y"] + raise AttributeError("'AnalogGenieData' object has no attribute 'y'") + + @y.setter + def y(self, value): + """Set target tensor. + + Parameters + ---------- + value : torch.Tensor + Target tensor to store. + """ + self._store["y"] = value + + +class AnalogGenieDataset(InMemoryDataset): + """Hypergraph dataset for AnalogGenie circuits. + + Parameters + ---------- + root : str + Root directory for storing data. + name : str + Dataset name. + parameters : dict + Loader parameters. + transform : callable, optional + Optional transform. + pre_transform : callable, optional + Optional pre-transform. + """ + + URLS = { + "AnalogGenie": "1-5-gS9aK2-a-d-Y-g-f-X-e-Y-k-Z-w-Y-k-d-Y-k-d-Y-k-d" + } + FILE_FORMAT = {"AnalogGenie": "zip"} + + def __init__( + self, root, name, parameters, transform=None, pre_transform=None + ): + """Initialize dataset.""" + self.name = name + self.parameters = parameters + super().__init__(root, transform, pre_transform) + self.data, self.slices = torch.load( + self.processed_paths[0], weights_only=False + ) + + @property + def raw_dir(self): + """Return raw directory path. + + Returns + ------- + str + Raw directory path. + """ + return os.path.join(self.root, self.name, "raw") + + @property + def processed_dir(self): + """Return processed directory path. + + Returns + ------- + str + Processed directory path. + """ + return os.path.join(self.root, self.name, "processed") + + @property + def raw_file_names(self): + """List required raw files. + + Returns + ------- + list + Raw file names. + """ + return ["README.md"] + + @property + def processed_file_names(self): + """List processed files. + + Returns + ------- + list + Processed file names. + """ + return ["data.pt"] + + def download(self): + """Download the dataset from a URL and saves it to the raw directory.""" + # Check if the 'Dataset' directory (indicating successful download/extraction) already exists + if os.path.exists(os.path.join(self.raw_dir, "Dataset")): + print("Raw data already exists. 
Skipping download.") + return + + # GitHub repo with .cir files + self.url = "https://github.com/xz-group/AnalogGenie/archive/refs/heads/main.zip" + self.file_format = "zip" + os.makedirs(self.raw_dir, exist_ok=True) + download_file_from_link( + file_link=self.url, + path_to_save=self.raw_dir, + dataset_name=self.name, # names the zip AnalogGenie.zip + file_format=self.file_format, + ) + + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = os.path.join(folder, filename) + + if not os.path.exists(path): + raise FileNotFoundError(f"Downloaded file not found at {path}") + + extract_zip(path, folder) + os.unlink(path) # Remove the zip file after extraction + + # Find the extracted repo folder and move contents up to raw_dir + extracted_folder_name = "" + for item in os.listdir(self.raw_dir): + if os.path.isdir( + os.path.join(self.raw_dir, item) + ) and item.lower().startswith("analoggenie"): + extracted_folder_name = item + break + + if extracted_folder_name: + source_folder = os.path.join(folder, extracted_folder_name) + for file in os.listdir(source_folder): + shutil.move(os.path.join(source_folder, file), folder) + shutil.rmtree(source_folder) + else: + print( + "Warning: extracted folder not found; assuming contents already at raw root." + ) + + def _create_component_vocab(self): + """Build vocabulary of component types. + + Returns + ------- + dict + Mapping from component type to ID. + """ + return { + "capacitor": 0, + "nmos4": 1, + "pmos4": 2, + "resistor": 3, + "unknown": 4, + } + + def _get_incidence_roles( + self, component_type: str, nodes: list[str] + ) -> list[int]: + """Assign pin-role codes per node for a component. + + Parameters + ---------- + component_type : str + Component type identifier. + nodes : list of str + Node names for this component. + + Returns + ------- + list of int + Role codes aligned with nodes. + """ + roles = [] + comp = component_type.lower() + if comp in {"nmos4", "pmos4", "nmos", "pmos", "mos", "mos4"}: + template = [1, 2, 3, 4] # drain, gate, source, bulk + for idx, _ in enumerate(nodes): + roles.append(template[idx] if idx < len(template) else 0) + elif comp in {"bjt", "npn", "pnp"}: + template = [11, 12, 13] # collector, base, emitter + for idx, _ in enumerate(nodes): + roles.append(template[idx] if idx < len(template) else 0) + else: + roles = [0] * len(nodes) + return roles + + def _get_node_feature(self, node_name: str) -> int: + """Map a node name to a categorical feature code. + + Parameters + ---------- + node_name : str + Name of the node in the netlist. + + Returns + ------- + int + Encoded feature ID (0 generic, 1 power, 2 ground, 3 input, 4 output, + 5 bias, 6 gate, 7 drain, 8 source, 9 bulk, 10 clock). + """ + name = node_name.lower() + if any(tok in name for tok in ["vdd", "vcc", "power", "pwr"]): + return 1 + if any(tok in name for tok in ["vss", "gnd", "ground"]): + return 2 + if "clk" in name: + return 10 + if "bias" in name or name.startswith("vb"): + return 5 + if name.startswith(("in", "vin")) or "input" in name: + return 3 + if name.startswith(("out", "vout")) or "output" in name: + return 4 + if "gate" in name: + return 6 + if "drain" in name: + return 7 + if "source" in name: + return 8 + if any(tok in name for tok in ["bulk", "body", "substrate"]): + return 9 + return 0 + + def _parse_params(self, params: list[str]) -> list[float]: + """Parse parameter tokens into a list of floats. + + Parameters + ---------- + params : list of str + Raw parameter tokens from the netlist line. 
+    def len(self):
+        """Return dataset length.
+
+        Returns
+        -------
+        int
+            Number of graphs.
+        """
+        if self.slices is None:
+            return 1
+        return self.slices["x"].size(0) - 1
+
+    def get(self, idx):
+        """Get an item by index.
+
+        Parameters
+        ----------
+        idx : int
+            Index of the graph.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            Retrieved graph object.
+        """
+        if self.slices is None:
+            if idx != 0:
+                raise IndexError(
+                    "Index out of range for single-graph dataset."
+                )
+            return self.data
+        return super().get(idx)
+
+    def process(self):  # noqa: D401
+        """Process raw AnalogGenie data into hypergraph tensors."""
+        Path(self.processed_dir).mkdir(
+            parents=True, exist_ok=True
+        )  # Ensure processed directory exists
+        data_list = []
+        component_vocab = self._create_component_vocab()
+
+        # Discover all .cir files in the raw directory
+        circuit_files = glob.glob(
+            os.path.join(self.raw_dir, "Dataset", "*", "*.cir")
+        )
+        if not circuit_files:
+            empty = AnalogGenieData(x=torch.empty((0, 0), dtype=torch.float))
+            empty_slices = {"x": torch.tensor([0])}
+            torch.save((empty, empty_slices), self.processed_paths[0])
+            return
+
+        def parse_spice_netlist(cir_path):
+            """Parse a SPICE netlist with simple subcircuit expansion.
+
+            Parameters
+            ----------
+            cir_path : str
+                Path to the circuit file.
+
+            Returns
+            -------
+            list[dict]
+                Parsed component dictionaries.
+            """
+            components = []
+            subckts = {}
+            with open(cir_path) as f:
+                lines = f.readlines()
+
+            # Merge "+" continuation lines; skip comments and blank lines
+            merged = []
+            buffer = ""
+            for line in lines:
+                line = line.strip()
+                if not line or line.startswith(("*", "//")):
+                    continue
+                if line.startswith("+"):
+                    buffer += " " + line[1:].strip()
+                    continue
+                if buffer:
+                    merged.append(buffer)
+                buffer = line
+            if buffer:
+                merged.append(buffer)
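+            # Illustrative: the raw lines
+            #     M1 d g s b nmos4
+            #     + w=2u l=0.18u
+            # reach this point merged as "M1 d g s b nmos4 w=2u l=0.18u".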
+ """ + tokens = line.replace("(", " ").replace(")", " ").split() + if not tokens: + return None + name = tokens[0] + prefix = name[0].lower() + if prefix == "x": # subckt instance + subckt_name = tokens[-1] + nets = tokens[1:-1] + return { + "type": "subckt", + "subckt": subckt_name, + "nets": nets, + "name": name, + } + elif prefix in {"m"}: # mos + nodes = tokens[1:5] + model = tokens[5] if len(tokens) > 5 else "mos" + params = tokens[6:] if len(tokens) > 6 else [] + return { + "type": model, + "nodes": nodes, + "params": params, + "name": name, + } + elif prefix in {"q"}: # bjt + nodes = tokens[1:4] + model = tokens[4] if len(tokens) > 4 else "bjt" + params = tokens[5:] if len(tokens) > 5 else [] + return { + "type": model, + "nodes": nodes, + "params": params, + "name": name, + } + else: + # generic two-terminal + nodes = tokens[1:3] + ctype = tokens[3] if len(tokens) > 3 else tokens[0][0] + params = ( + tokens[4:] + if len(tokens) > 4 + else tokens[3:] + if len(tokens) > 3 + else [] + ) + return { + "type": ctype, + "nodes": nodes, + "params": params, + "name": name, + } + + def expand(instance, prefix=""): + """Recursively expand components and subcircuits. + + Parameters + ---------- + instance : dict + Component or subcircuit instance. + prefix : str, optional + Name prefix for nested instances. + + Returns + ------- + list[dict] + Expanded component list. + """ + if instance.get("type") != "subckt": + return [instance] + sub_name = instance["subckt"] + if sub_name not in subckts: + return [] + mapping = dict( + zip( + subckts[sub_name]["pins"], + instance["nets"], + strict=True, + ) + ) + expanded = [] + for line in subckts[sub_name]["body"]: + comp = parse_component(line) + if comp is None: + continue + if comp.get("type") == "subckt": + remapped = [mapping.get(n, n) for n in comp["nets"]] + comp["nets"] = remapped + expanded += expand( + comp, prefix + instance["name"] + "." + ) + else: + comp["nodes"] = [ + mapping.get(n, n) for n in comp.get("nodes", []) + ] + comp["name"] = ( + prefix + + instance["name"] + + "." 
+ + comp.get("name", "") + ) + expanded.append(comp) + return expanded + + parsed = [] + for line in components: + comp = parse_component(line) + if comp is None: + continue + parsed += expand(comp) + return parsed + + for circuit_file in circuit_files: + components = parse_spice_netlist(circuit_file) + + node_map = {} + hyperedge_index = [] + component_features = [] + incidence_roles = [] + hyperedge_params = [] + + for idx_comp, comp in enumerate(components): + ctype = comp.get("type", "unknown") + nodes = comp.get("nodes", []) + params = comp.get("params", []) + component_features.append( + component_vocab.get( + ctype.lower(), component_vocab["unknown"] + ) + ) + role_codes = self._get_incidence_roles(ctype, nodes) + hyperedge_params.append(self._parse_params(params)) + + for node in nodes: + if node not in node_map: + node_map[node] = len(node_map) + hyperedge_index.append([node_map[node], idx_comp]) + incidence_roles.extend(role_codes[: len(nodes)]) + + # Create node features from node names (power/ground/generic) + x = torch.tensor( + [ + self._get_node_feature(name) + for name, _ in sorted( + node_map.items(), key=lambda kv: kv[1] + ) + ], + dtype=torch.float, + ).view(-1, 1) + + if not hyperedge_index: + # Handle empty graphs + hyperedge_index = torch.empty((2, 0), dtype=torch.long) + incidence_roles = torch.empty((0,), dtype=torch.long) + hyperedge_params_tensor = torch.empty( + (0, 0), dtype=torch.float + ) + else: + hyperedge_index = ( + torch.tensor(hyperedge_index, dtype=torch.long) + .t() + .contiguous() + ) + incidence_roles = torch.tensor( + incidence_roles, dtype=torch.long + ) + max_len = max((len(p) for p in hyperedge_params), default=0) + hyperedge_params_tensor = torch.zeros( + (len(hyperedge_params), max_len), dtype=torch.float + ) + for i, vals in enumerate(hyperedge_params): + if not vals: + continue + end = min(len(vals), max_len) + hyperedge_params_tensor[i, :end] = torch.tensor( + vals[:end], dtype=torch.float + ) + + data = AnalogGenieData(x=x, hyperedge_index=hyperedge_index) + # Add component features as hyperedge attributes + data.hyperedge_attr = torch.tensor( + component_features, dtype=torch.long + ) + data.incidence_roles = incidence_roles + data.hyperedge_params = hyperedge_params_tensor + data_list.append(data) + + if self.pre_filter is not None: + data_list = [data for data in data_list if self.pre_filter(data)] + + if self.pre_transform is not None: + data_list = [self.pre_transform(data) for data in data_list] + + if not data_list: + empty = AnalogGenieData(x=torch.empty((0, 0), dtype=torch.float)) + empty_slices = {"x": torch.tensor([0])} + torch.save((empty, empty_slices), self.processed_paths[0]) + return + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) diff --git a/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py b/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py new file mode 100644 index 000000000..f78fd85eb --- /dev/null +++ b/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py @@ -0,0 +1,60 @@ +"""Loader for AICircuit dataset.""" + +from omegaconf import DictConfig + +from topobench.data.datasets.aicircuit_datasets import AICircuitDataset +from topobench.data.loaders.base import AbstractLoader + + +class AICircuitDatasetLoader(AbstractLoader): + """Load AICircuit dataset with configurable parameters. 
diff --git a/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py b/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py
new file mode 100644
index 000000000..f78fd85eb
--- /dev/null
+++ b/topobench/data/loaders/hypergraph/aicircuit_dataset_loader.py
@@ -0,0 +1,60 @@
+"""Loader for AICircuit dataset."""
+
+from omegaconf import DictConfig
+
+from topobench.data.datasets.aicircuit_datasets import AICircuitDataset
+from topobench.data.loaders.base import AbstractLoader
+
+
+class AICircuitDatasetLoader(AbstractLoader):
+    """Load the AICircuit dataset with configurable parameters.
+
+    Parameters
+    ----------
+    parameters : DictConfig
+        Configuration parameters containing:
+        - data_dir: Root directory for data
+        - data_name: Name of the dataset
+        - other relevant parameters
+    cfg : DictConfig, optional
+        A DictConfig object containing configuration for the dataset itself,
+        including parameters for the dataset and split_params.
+    """
+
+    def __init__(
+        self, parameters: DictConfig, cfg: DictConfig | None = None
+    ) -> None:
+        super().__init__(parameters)
+        self.cfg = cfg  # Store the cfg for dataset initialization
+
+    def load_dataset(self) -> AICircuitDataset:
+        """Load the AICircuit dataset.
+
+        Returns
+        -------
+        AICircuitDataset
+            The loaded AICircuit dataset with the appropriate `data_dir`.
+
+        Raises
+        ------
+        RuntimeError
+            If dataset loading fails.
+        """
+        dataset = self._initialize_dataset()
+        self.data_dir = self.get_data_dir()
+        return dataset
+
+    def _initialize_dataset(self) -> AICircuitDataset:
+        """Initialize the AICircuit dataset.
+
+        Returns
+        -------
+        AICircuitDataset
+            The initialized dataset instance.
+        """
+        return AICircuitDataset(
+            root=str(self.root_data_dir),
+            name=self.parameters.data_name,
+            parameters=self.cfg.parameters if self.cfg is not None else None,
+        )
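Both loaders are normally constructed by Hydra via `_target_`; a hand-rolled sketch, assuming `AbstractLoader` derives its root directory from these parameters (keys and path are illustrative):

    from omegaconf import OmegaConf

    from topobench.data.loaders.hypergraph.aicircuit_dataset_loader import (
        AICircuitDatasetLoader,
    )

    params = OmegaConf.create(
        {
            "data_dir": "datasets/hypergraph/analog_circuit",
            "data_name": "AICircuit",
        }
    )
    loader = AICircuitDatasetLoader(parameters=params)
    dataset = loader.load_dataset()
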
+ """ + return AnalogGenieDataset( + root=str(self.root_data_dir), + name=self.parameters.data_name, + parameters=self.cfg.parameters if self.cfg is not None else None, + ) diff --git a/topobench/data/utils/__init__.py b/topobench/data/utils/__init__.py index 8793f773e..4bb899d40 100644 --- a/topobench/data/utils/__init__.py +++ b/topobench/data/utils/__init__.py @@ -47,6 +47,7 @@ from .io_utils import ( # noqa: E402 download_file_from_drive, # noqa: F401 download_file_from_link, # noqa: F401 + extract_zip, # noqa: F401 load_hypergraph_content_dataset, # noqa: F401 load_hypergraph_pickle_dataset, # noqa: F401 read_ndim_manifolds, # noqa: F401 @@ -59,6 +60,7 @@ "load_hypergraph_content_dataset", "read_us_county_demos", "download_file_from_drive", + "extract_zip", # add function name here ] diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 372db85e6..91aff4ac1 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -3,6 +3,7 @@ import json import os.path as osp import pickle +import zipfile from urllib.parse import parse_qs, urlparse import numpy as np @@ -115,6 +116,20 @@ def download_file_from_link( print("Failed to download the file.") +def extract_zip(zip_file_path, extract_path): + """Extract a zip file to a specified path. + + Parameters + ---------- + zip_file_path : str + Path to the zip archive. + extract_path : str + Destination directory for extraction. + """ + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + zip_ref.extractall(extract_path) + + def read_ndim_manifolds( path, dim, y_val="betti_numbers", slice=None, load_as_graph=False ):