diff --git a/.gitignore b/.gitignore index e618d5221..0b5e6eeb3 100755 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,7 @@ eggs/ .vscode/ lib/ logs/ +lightning_logs/ lib64/ parts/ sdist/ diff --git a/configs/dataset/graph/a123.yaml b/configs/dataset/graph/a123.yaml new file mode 100644 index 000000000..abd989887 --- /dev/null +++ b/configs/dataset/graph/a123.yaml @@ -0,0 +1,57 @@ +# Config file for loading the Bowen et al. mouse auditory cortex calcium imaging dataset. + +# The corresponding dataset module downloads and processes the original dataset introduced in: + +# [Citation] Bowen et al. (2024), "Fractured columnar small-world functional network +# organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. +# https://doi.org/10.1093/pnasnexus/pgae074 + +# We apply the preprocessing and graph-construction steps defined in the dataset module to obtain +# a representation of neuronal activity suitable for our experiments. + +# Please cite the original paper when using this dataset or any derivatives. + +# Dataset loader config for A123 Cortex M +loader: + _target_: topobench.data.loaders.A123DatasetLoader + parameters: + data_domain: graph + data_type: A123CortexM + data_dir: ${paths.data_dir}/${dataset.loader.parameters.data_domain}/${dataset.loader.parameters.data_type} # Use data_dir from dataset config + data_name: a123_cortex_m # Use data_name from dataset config + num_graphs: 10 + is_undirected: True + num_channels: ${dataset.parameters.num_features} # Use num_features for node feature dim + num_classes: ${dataset.parameters.num_classes} # Use num_classes for output dim + task: ${dataset.parameters.task} # Use task type from dataset config + +# Dataset-specific parameters +parameters: + num_features: 3 + task: classification + specific_task: classification # Current task selection (classification | triangle_classification | triangle_common_neighbors) + num_classes: 9 + loss_type: cross_entropy + monitor_metric: accuracy + task_level: graph + min_neurons: 3 + corr_threshold: 0.2 + + +# Splits +split_params: + learning_setting: inductive + data_split_dir: ${paths.data_dir}/data_splits/${dataset.loader.parameters.data_name} + data_seed: 0 + split_type: random # 'k-fold' or 'random' strategies + k: 10 # for "k-fold" Cross-Validation + train_prop: 0.7 # for "random" strategy splitting + val_prop: 0.15 # for "random" strategy splitting + test_prop: 0.15 # for "random" strategy splitting + +# Dataloader parameters +dataloader_params: + batch_size: 32 + num_workers: 0 + pin_memory: False + diff --git a/env_setup.sh b/env_setup.sh index 8dfd41344..dac3df247 100755 --- a/env_setup.sh +++ b/env_setup.sh @@ -4,7 +4,7 @@ pip install -e '.[all]' # Note that not all combinations of torch and CUDA are available # See https://github.com/pyg-team/pyg-lib to check the configuration that works for you -TORCH="2.3.0" # available options: 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, ... +TORCH="2.1.0" # available options: 2.0.0, 2.1.0, 2.2.0, 2.3.0, 2.4.0, ... CUDA="cpu" # if available, select the CUDA version suitable for your system # available options: cpu, cu102, cu113, cu116, cu117, cu118, cu121, ... pip install torch==${TORCH} --extra-index-url https://download.pytorch.org/whl/${CUDA} diff --git a/test/data/load/test_a123_dataset.py b/test/data/load/test_a123_dataset.py new file mode 100644 index 000000000..422a3b7e1 --- /dev/null +++ b/test/data/load/test_a123_dataset.py @@ -0,0 +1,875 @@ +"""Unit tests for A123 mouse auditory cortex dataset. 
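For orientation, the config above is consumed through Hydra, and the tests below compose it in exactly this way. A minimal usage sketch (note that config_path is resolved relative to the calling module, so adjust it for your location; the overrides mirror the test setup in this file):

import hydra

with hydra.initialize(version_base="1.3", config_path="configs", job_name="example"):
    cfg = hydra.compose(
        config_name="run.yaml",
        overrides=["dataset=graph/a123", "model=graph/gat", "paths=test"],
        return_hydra_config=True,
    )
    loader = hydra.utils.instantiate(cfg.dataset.loader)  # -> A123DatasetLoader
    dataset = loader.load_dataset()  # downloads and processes on first call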
+ +Tests cover: +- Graph-level dataset loading and structure +- Triangle classification task functionality +- Triangle role classification logic +- Feature dimensions and data integrity +- Configuration parameter handling +""" + +import pytest +import torch +import hydra +import networkx as nx +from pathlib import Path +from omegaconf import DictConfig +from torch_geometric.data import Data + +# Import the dataset and loader +from topobench.data.datasets.a123 import A123CortexMDataset, TriangleClassifier +from topobench.data.loaders.graph.a123_loader import A123DatasetLoader + + +def pytest_addoption(parser): + """Add command-line options for pytest. + + Parameters + ---------- + parser : pytest.Parser + Pytest command-line parser. + """ + parser.addoption( + "--specific-task", + action="store", + default=None, + help="Filter tests by specific_task type: classification, triangle_classification, triangle_common_neighbors. " + "If not specified, ALL tests run.", + ) + + +def pytest_collection_modifyitems(config, items): + """Skip tests that don't match the specified task type. + + Parameters + ---------- + config : pytest.Config + Pytest configuration object. + items : list + List of collected test items. + """ + task = config.getoption("--specific-task") + if not task: + return + + # Map task types to test class names + task_mapping = { + "classification": [ + "TestA123GraphDataset", + "TestA123Configuration", + "TestA123DataIntegrity", + ], + "triangle_classification": [ + "TestTriangleClassifier", + "TestTriangleTask", + ], + "triangle_common_neighbors": ["TestTriangleCommonNeighborsTask"], + } + + if task not in task_mapping: + raise ValueError( + f"Invalid --specific-task '{task}'. " + f"Must be one of: {', '.join(task_mapping.keys())}" + ) + + allowed_classes = set(task_mapping[task]) + skipped = 0 + kept = 0 + + for item in items: + # Extract class name from item nodeid + # nodeid format: path/to/file.py::ClassName::test_method_name + class_name = None + if "::" in item.nodeid: + parts = item.nodeid.split("::") + if len(parts) >= 2: + class_name = parts[1] + + if class_name and class_name not in allowed_classes: + item.add_marker( + pytest.mark.skip( + reason=f"Test class {class_name} not for task '{task}'" + ) + ) + skipped += 1 + else: + kept += 1 + + if kept > 0 or skipped > 0: + print(f"\n{'=' * 70}") + print(f"[Task Filter] Task: '{task}'") + print(f" ✓ Kept: {kept} tests - {', '.join(allowed_classes)}") + print(f" ✗ Skipped: {skipped} tests") + print(f"{'=' * 70}\n") + + +class TestA123GraphDataset: + """Test suite for A123 graph-level dataset.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + self.config_path = "../../../configs" + + def test_dataset_loading(self): + """Test basic dataset loading and instantiation.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + # Instantiate loader + loader = hydra.utils.instantiate(cfg.dataset.loader) + # Check loader type using class name to handle module reloading issues + assert type(loader).__name__ == "A123DatasetLoader" or hasattr( + loader, "load_dataset" + ) + + # Load dataset + dataset = loader.load_dataset() + assert dataset is not None + assert hasattr(dataset, "data") + assert isinstance(dataset.data, Data) + + def 
test_graph_dataset_properties(self): + """Test graph-level dataset has correct properties.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset = loader.load_dataset() + + # Check dataset structure + assert hasattr(dataset, "num_node_features") + assert hasattr(dataset, "num_classes") + # Note: InMemoryDataset doesn't expose num_graphs directly; it uses slices internally + + # Check feature dimensions + assert ( + dataset.num_node_features == 3 + ) # mean_corr, std_corr, noise_diag + assert dataset.num_classes == 9 # 9 frequency bins + + # Check data integrity + assert dataset.data.x is not None + assert dataset.data.edge_index is not None + assert dataset.data.y is not None + + # Check labels are in valid range + assert torch.all(dataset.data.y >= 0) + assert torch.all(dataset.data.y < dataset.num_classes) + + def test_graph_node_features(self): + """Test that node features are correctly structured.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset = loader.load_dataset() + + # Check node features + x = dataset.data.x + assert x.shape[1] == 3 # 3 features per node + assert torch.isfinite(x).all() # No NaN or Inf values + + +class TestA123Configuration: + """Test suite for A123 configuration integrity.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + self.config_path = "../../../configs" + + def test_dataset_parameters(self): + """Test that all required dataset parameters are configured.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + params = cfg.dataset.parameters + + # Check required parameters + assert hasattr(params, "num_features") + assert hasattr(params, "num_classes") + assert hasattr(params, "task") + assert hasattr(params, "min_neurons") + assert hasattr(params, "corr_threshold") + + # Check values are sensible + assert params.num_features == 3 + assert params.num_classes == 9 + assert params.task == "classification" + assert params.min_neurons >= 3 + assert 0.0 <= params.corr_threshold <= 1.0 + + def test_loader_parameters(self): + """Test that loader configuration is correct.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader_cfg = cfg.dataset.loader + + # Check loader configuration + assert hasattr(loader_cfg, "parameters") + params = loader_cfg.parameters + + assert hasattr(params, "data_domain") + assert hasattr(params, "data_type") + assert hasattr(params, "is_undirected") + + # Verify values + assert params.data_domain == "graph" + assert params.data_type == "A123CortexM" + assert 
params.is_undirected is True + + +class TestA123DataIntegrity: + """Test suite for data integrity checks.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + self.config_path = "../../../configs" + + def test_feature_format(self): + """Test that features are in correct format.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset = loader.load_dataset() + + # Check feature tensor + x = dataset.data.x + assert x.dtype in [torch.float32, torch.float64] + assert torch.isfinite(x).all() + + def test_edge_index_format(self): + """Test that edge indices are correctly formatted.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset = loader.load_dataset() + + # Check edge index + edge_index = dataset.data.edge_index + assert edge_index.dtype == torch.long + assert edge_index.shape[0] == 2 # [2, num_edges] + assert torch.all(edge_index >= 0) # No negative indices + + def test_labels_format(self): + """Test that labels are correctly formatted.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + dataset = loader.load_dataset() + + # Check labels + y = dataset.data.y + assert y.dtype == torch.long + assert torch.all(y >= 0) + assert torch.all(y < dataset.num_classes) + + +class TestTriangleClassifier: + """Test suite for TriangleClassifier helper class.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + self.roles = [ + "core_strong", + "core_medium", + "core_weak", + "bridge_strong", + "bridge_medium", + "bridge_weak", + "isolated_strong", + "isolated_medium", + "isolated_weak", + ] + + def test_triangle_classifier_initialization(self): + """Test TriangleClassifier can be instantiated.""" + classifier = TriangleClassifier(min_weight=0.2) + assert classifier is not None + assert classifier.min_weight == 0.2 + + def test_role_to_label_mapping(self): + """Test that triangle roles map to correct labels (0-8).""" + classifier = TriangleClassifier(min_weight=0.2) + + # Test all 9 roles map to integers 0-8 + for role in self.roles: + label = classifier._role_to_label(role) + assert isinstance(label, int) + assert 0 <= label <= 8 + # Verify deterministic mapping (same role -> same label) + assert label == classifier._role_to_label(role) + + def test_role_classification_logic(self): + """Test that role classification works with synthetic triangle data.""" + classifier = TriangleClassifier(min_weight=0.2) + + # Create a simple networkx graph with a triangle + G = nx.Graph() + G.add_edge(0, 1, weight=0.8) + G.add_edge(1, 2, weight=0.7) + G.add_edge(0, 2, weight=0.6) + # Add some other nodes connected to all three to test embedding class + G.add_edge(3, 0, 
weight=0.5) + G.add_edge(3, 1, weight=0.5) + G.add_edge(3, 2, weight=0.5) + + # Test role classification with the graph + nodes = (0, 1, 2) + edge_weights = [0.8, 0.7, 0.6] + + role = classifier._classify_role(G, nodes, edge_weights) + assert role is not None + assert isinstance(role, str) + assert role in self.roles + + def test_triangle_extraction_simple(self): + """Test triangle extraction on a simple graph.""" + classifier = TriangleClassifier(min_weight=0.2) + + # Create a simple graph: complete triangle (3-clique) + # Nodes: 0, 1, 2 + # Edges: (0,1), (1,2), (0,2) + edge_index = torch.tensor( + [[0, 1, 0, 1, 2, 0], [1, 0, 2, 2, 1, 2]] + ) # Undirected edges + edge_weights = torch.tensor([0.9, 0.9, 0.8, 0.8, 0.9, 0.9]) + + triangles = classifier.extract_triangles( + edge_index, edge_weights, num_nodes=3 + ) + + # Should find at least one triangle + assert len(triangles) > 0 + + # Each triangle should have required fields + for tri in triangles: + assert "nodes" in tri + assert "edge_weights" in tri + assert "label" in tri + assert "role" in tri + + # Verify structure + assert len(tri["nodes"]) == 3 + assert len(tri["edge_weights"]) == 3 + assert isinstance(tri["label"], int) + assert 0 <= tri["label"] <= 8 + assert isinstance(tri["role"], str) + + +class TestTriangleTask: + """Test suite for triangle classification task functionality.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + self.config_path = "../../../configs" + + def test_triangle_task_configuration(self): + """Test that triangle task is properly configured.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + # Check triangle task configuration + # specific_task is a string selector, not a nested config + assert hasattr(cfg.dataset.parameters, "specific_task") + specific_task = cfg.dataset.parameters.specific_task + assert isinstance(specific_task, str) + + # Valid tasks include triangle_classification + assert specific_task in [ + "classification", + "triangle_classification", + "triangle_common_neighbors", + ] + + # All tasks use 9 classes + assert cfg.dataset.parameters.num_classes == 9 + + def test_minimal_features_in_config(self): + """Test that triangle features are optimized to minimal set (3D edge weights).""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + # Triangle tasks use 3 node features (edge weights in triangle context) + # This is a fixed feature dimension for all triangle-based tasks + # Verify the base configuration uses 3 features + num_features = cfg.dataset.parameters.num_features + assert num_features == 3, ( + f"Expected 3D features (edge weights only), got {num_features}D. 
" + "Features should be: [weight_01, weight_02, weight_12]" + ) + + def test_triangle_loader_instantiation(self): + """Test that dataset loader can instantiate with triangle task.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + # Check loader type using class name to handle module reloading issues + assert type(loader).__name__ == "A123DatasetLoader" or hasattr( + loader, "load_dataset" + ) + + # Verify loader has load_dataset method with task_type parameter + assert hasattr(loader, "load_dataset") + assert callable(loader.load_dataset) + + def test_graph_vs_triangle_task_independent(self): + """Test that graph and triangle tasks are independent. + + Graph task should always work. Triangle task may require + prior processing, but shouldn't affect graph task. + """ + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + + # Graph task should always work + graph_dataset = loader.load_dataset() + assert graph_dataset is not None + assert graph_dataset.num_classes == 9 + + # Triangle task may fail if not processed yet, but shouldn't crash loader + try: + triangle_dataset = loader.load_dataset() + if triangle_dataset is not None: + # If triangle dataset loaded, verify it has expected properties + assert hasattr(triangle_dataset, "data") + except FileNotFoundError: + # Expected if triangle processing hasn't been run + pass + + +class TestTriangleCommonNeighborsTask: + """Test suite for triangle common-neighbors task (TDL-focused).""" + + @pytest.fixture(autouse=True) + def setup(self): + """Setup test environment.""" + hydra.core.global_hydra.GlobalHydra.instance().clear() + self.config_path = "../../../configs" + + def test_triangle_common_task_configuration(self): + """Test that triangle common-neighbors task is configured.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + # Check that specific_task can be set to triangle_common_neighbors + # It's a string selector in parameters, not a nested config structure + assert hasattr(cfg.dataset.parameters, "specific_task") + specific_task = cfg.dataset.parameters.specific_task + + # Verify it's a valid task selector + assert isinstance(specific_task, str) + assert specific_task in [ + "classification", + "triangle_classification", + "triangle_common_neighbors", + ] + + # All tasks use 9 classes (unified output) + assert cfg.dataset.parameters.num_classes == 9 + + def test_triangle_common_loader_instantiation(self): + """Test that loader can instantiate with triangle common-neighbors task.""" + with hydra.initialize( + version_base="1.3", config_path=self.config_path, job_name="test" + ): + cfg = hydra.compose( + config_name="run.yaml", + overrides=[ + "dataset=graph/a123", + "model=graph/gat", + "paths=test", + ], + return_hydra_config=True, + ) + + loader = hydra.utils.instantiate(cfg.dataset.loader) + # Check 
loader has load_dataset method + assert hasattr(loader, "load_dataset") + assert callable(loader.load_dataset) + + def test_triangle_common_task_creation_synthetic(self): + """Test triangle common-neighbors task creation on synthetic graph.""" + classifier = TriangleClassifier(min_weight=0.2) + + # Create synthetic graph with known structure + # Nodes: 0, 1, 2, 3, 4 + # Triangles: (0,1,2) with common neighbor 3, (1,2,3) with common neighbor 4 + edge_list = [ + (0, 1), + (0, 2), + (1, 2), # Triangle 0-1-2 + (1, 3), + (2, 3), + (0, 3), # Add node 3 as common neighbor + (1, 4), + (2, 4), + (3, 4), # Add node 4 as common neighbor + ] + + edge_index_list = [] + for u, v in edge_list: + edge_index_list.append([u, v]) + edge_index_list.append([v, u]) # Undirected + + edge_index = ( + torch.tensor(edge_index_list, dtype=torch.long).t().contiguous() + ) + edge_weights = torch.ones(edge_index.shape[1]) + + # Extract triangles + triangles = classifier.extract_triangles( + edge_index, edge_weights, num_nodes=5 + ) + + # Verify triangles were found + assert len(triangles) > 0 + + # Now simulate creating CN features for each triangle + # Build graph to compute common neighbors + G = nx.Graph() + G.add_nodes_from(range(5)) + for i in range(edge_index.shape[1]): + u = int(edge_index[0, i].item()) + v = int(edge_index[1, i].item()) + G.add_edge(u, v) + + # For each triangle, compute common neighbors + for tri in triangles: + a, b, c = tri["nodes"] + common = ( + set(G.neighbors(a)) & set(G.neighbors(b)) & set(G.neighbors(c)) + ) - {a, b, c} + num_common = len(common) + + # Features: node degrees (structural, no weights) + deg_a = G.degree(a) + deg_b = G.degree(b) + deg_c = G.degree(c) + + # Verify features are reasonable + assert deg_a > 0 and deg_b > 0 and deg_c > 0 + assert num_common >= 0 + + def test_triangle_common_features_are_structural(self): + """Test that CN task features are purely structural (node degrees).""" + # Create a simple triangle with known degrees + edge_index = torch.tensor( + [[0, 0, 1, 2], [1, 2, 2, 0]], dtype=torch.long + ) # Triangle 0-1-2 + extra edge 0-1 + edge_weights = torch.ones(edge_index.shape[1]) + + classifier = TriangleClassifier(min_weight=0.2) + triangles = classifier.extract_triangles( + edge_index, edge_weights, num_nodes=3 + ) + + # Build graph + G = nx.Graph() + G.add_nodes_from(range(3)) + for i in range(edge_index.shape[1]): + u = int(edge_index[0, i].item()) + v = int(edge_index[1, i].item()) + G.add_edge(u, v) + + # For triangle, extract degree features + for tri in triangles: + a, b, c = tri["nodes"] + deg_a = G.degree(a) + deg_b = G.degree(b) + deg_c = G.degree(c) + + tri_feats = torch.tensor( + [deg_a, deg_b, deg_c], dtype=torch.float32 + ) + + # Features should be non-negative integers (degrees) + assert tri_feats.shape == (3,) + assert torch.all(tri_feats >= 0) + # Degrees should be integers stored as floats + assert torch.allclose(tri_feats, tri_feats.round()) + + def test_triangle_common_label_semantics(self): + """Test that CN task labels represent common neighbor counts.""" + # Create graph where we know common neighbor counts + # 4 nodes: (0,1,2) form triangle with no external connections (CN=0) + # Add node 3 connected to all three: (0,1,2) will have CN=1 + edges_no_cn = torch.tensor( + [[0, 0, 1], [1, 2, 2]], dtype=torch.long + ) # Triangle 0-1-2 + edge_weights_no_cn = torch.ones(edges_no_cn.shape[1]) + + classifier = TriangleClassifier(min_weight=0.2) + triangles_no_cn = classifier.extract_triangles( + edges_no_cn, edge_weights_no_cn, num_nodes=3 
+ ) + + # Build graph + G = nx.Graph() + G.add_nodes_from(range(3)) + for i in range(edges_no_cn.shape[1]): + u = int(edges_no_cn[0, i].item()) + v = int(edges_no_cn[1, i].item()) + G.add_edge(u, v) + + # Verify CN count for triangle with no external neighbors + for tri in triangles_no_cn: + a, b, c = tri["nodes"] + common = ( + set(G.neighbors(a)) & set(G.neighbors(b)) & set(G.neighbors(c)) + ) - {a, b, c} + assert len(common) == 0 # No common neighbors + + def test_triangle_common_vs_role_independence(self): + """Test that CN task is independent of role classification task.""" + # Both tasks should work without interference + # CN task focuses on structural degree measures + # Role task focuses on embedding + weight patterns + + # Create a simple graph + edge_index = torch.tensor( + [[0, 0, 1, 1, 2, 0], [1, 2, 2, 0, 0, 2]], dtype=torch.long + ) + edge_weights = torch.tensor([0.8, 0.7, 0.6, 0.6, 0.7, 0.8]) + + classifier = TriangleClassifier(min_weight=0.2) + triangles = classifier.extract_triangles( + edge_index, edge_weights, num_nodes=3 + ) + + # From triangle classifier, we get roles (based on weights + embedding) + for tri in triangles: + role = tri["role"] + label = tri["label"] + + # Role should be one of the 9 types + assert isinstance(role, str) + assert 0 <= label <= 8 + + # CN features would be independent (just degrees) + # So two triangles with different roles could have same degree features + assert "strong" in role or "medium" in role or "weak" in role + + def test_triangle_common_edge_cases(self): + """Test CN task handles edge cases gracefully.""" + classifier = TriangleClassifier(min_weight=0.2) + + # Empty graph + empty_edge_index = torch.empty((2, 0), dtype=torch.long) + empty_weights = torch.empty((0,)) + triangles_empty = classifier.extract_triangles( + empty_edge_index, empty_weights, num_nodes=0 + ) + assert len(triangles_empty) == 0 + + # Graph with no triangles (just edges) + linear_edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long) + linear_weights = torch.ones(linear_edge_index.shape[1]) + triangles_linear = classifier.extract_triangles( + linear_edge_index, linear_weights, num_nodes=3 + ) + assert len(triangles_linear) == 0 # No triangles in a path + + # Single triangle (minimal case) + single_tri_edge_index = torch.tensor( + [[0, 0, 1], [1, 2, 2]], dtype=torch.long + ) + single_tri_weights = torch.ones(single_tri_edge_index.shape[1]) + triangles_single = classifier.extract_triangles( + single_tri_edge_index, single_tri_weights, num_nodes=3 + ) + assert len(triangles_single) == 1 + assert len(triangles_single[0]["nodes"]) == 3 + + +if __name__ == "__main__": + """Run tests for each task type with clear output.""" + import sys + + tasks = [ + "classification", + "triangle_classification", + "triangle_common_neighbors", + ] + + print("\n" + "=" * 80) + print("RUNNING A123 DATASET TESTS FOR ALL TASK TYPES") + print("=" * 80) + + results = {} + + for task in tasks: + print(f"\n{'-' * 80}") + print(f"Running tests for: {task}") + print(f"{'-' * 80}\n") + + # Use pytest.main programmatically + exit_code = pytest.main( + [ + __file__, + f"--specific-task={task}", + "-v", + "--tb=short", + ] + ) + + results[task] = exit_code + + # Summary + print("\n" + "=" * 80) + print("SUMMARY") + print("=" * 80) + + for task, returncode in results.items(): + status = "✓ PASSED" if returncode == 0 else "✗ FAILED" + print(f"{status:12} - {task}") + + print("=" * 80 + "\n") + + # Exit with failure if any task failed + sys.exit(max(results.values())) diff --git 
a/test/data/utils/test_io_utils.py b/test/data/utils/test_io_utils.py index 85c09e9c8..941fd4a74 100644 --- a/test/data/utils/test_io_utils.py +++ b/test/data/utils/test_io_utils.py @@ -1,5 +1,9 @@ """Tests for the io_utils module.""" +import os +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch import pytest from topobench.data.utils.io_utils import * @@ -20,3 +24,334 @@ def test_get_file_id_from_url(): with pytest.raises(ValueError): get_file_id_from_url(url_wrong) + + +class TestDownloadFileFromLink: + """Test suite for download_file_from_link function.""" + + @pytest.fixture + def temp_dir(self): + """Create temporary directory for test outputs. + + Returns + ------- + str + Path to temporary directory. + """ + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + @pytest.fixture + def mock_response(self): + """Create mock response object. + + Returns + ------- + MagicMock + Mock response object with status code and headers. + """ + response = MagicMock() + response.status_code = 200 + response.headers = {"content-length": "5242880"} # 5 MB + response.elapsed.total_seconds.return_value = 1.0 + return response + + def test_download_success_with_progress(self, temp_dir, mock_response): + """Test successful download with progress reporting. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + # Setup mock chunks (5MB total in 1MB chunks) + chunk_data = [b"x" * (1024 * 1024) for _ in range(5)] + mock_response.iter_content.return_value = chunk_data + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # Verify file was created and has correct size + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 5 * 1024 * 1024 + + def test_download_creates_directory_if_not_exists(self, temp_dir): + """Test that download creates directory structure. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + nested_dir = os.path.join(temp_dir, "nested", "path") + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=nested_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(nested_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.isdir(nested_dir) + + def test_download_http_error(self, temp_dir): + """Test handling of HTTP error responses. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 404 + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/nonexistent.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should not be created on HTTP error + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert not os.path.exists(output_file) + + def test_download_timeout_retry(self, temp_dir): + """Test retry logic on timeout. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + import requests + + with patch("requests.get") as mock_get: + # First call times out, second succeeds + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-length": "1024"} + mock_response_success.elapsed.total_seconds.return_value = 0.5 + mock_response_success.iter_content.return_value = [b"x" * 1024] + + mock_get.side_effect = [ + requests.exceptions.Timeout("Connection timed out"), + mock_response_success, + ] + + with patch("time.sleep"): # Mock sleep to speed up test + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=3, + ) + + # File should be created on successful retry + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert mock_get.call_count == 2 + + def test_download_exhausts_retries(self, temp_dir): + """Test that exception is raised after all retries exhausted. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + import requests + + with patch("requests.get") as mock_get: + mock_get.side_effect = requests.exceptions.Timeout( + "Connection timed out" + ) + + with patch("time.sleep"): + with pytest.raises(requests.exceptions.Timeout): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=2, + ) + + # Verify retries were attempted + assert mock_get.call_count == 2 + + def test_download_with_different_formats(self, temp_dir, mock_response): + """Test download with different file formats. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + mock_response.iter_content.return_value = [b"test content"] + + formats = ["zip", "tar", "tar.gz"] + + with patch("requests.get", return_value=mock_response): + for fmt in formats: + download_file_from_link( + file_link="http://example.com/dataset", + path_to_save=temp_dir, + dataset_name=f"test_dataset_{fmt.replace('.', '_')}", + file_format=fmt, + timeout=60, + retries=1, + ) + + # Verify all files were created with correct extensions + for fmt in formats: + output_file = os.path.join( + temp_dir, f"test_dataset_{fmt.replace('.', '_')}.{fmt}" + ) + assert os.path.exists(output_file) + + def test_download_empty_chunks(self, temp_dir): + """Test handling of empty chunks in response. + + Parameters + ---------- + temp_dir : str + Temporary directory path. 
+ """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"content-length": "1024"} + mock_response.elapsed.total_seconds.return_value = 1.0 + # Include empty chunks (should be skipped) + mock_response.iter_content.return_value = [ + b"x" * 512, + b"", # Empty chunk + b"y" * 512, + b"", # Another empty chunk + ] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + # File should contain only non-empty chunks + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_unknown_size(self, temp_dir): + """Test download when content-length header is missing. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + """ + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {} # No content-length header + mock_response.elapsed.total_seconds.return_value = 0.5 + mock_response.iter_content.return_value = [b"x" * 1024] + + with patch("requests.get", return_value=mock_response): + download_file_from_link( + file_link="http://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=60, + retries=1, + ) + + output_file = os.path.join(temp_dir, "test_dataset.tar.gz") + assert os.path.exists(output_file) + assert os.path.getsize(output_file) == 1024 + + def test_download_ssl_verification_disabled(self, temp_dir, mock_response): + """Test that SSL verification can be disabled. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. + """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + download_file_from_link( + file_link="https://example.com/dataset.tar.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + verify=False, + timeout=60, + retries=1, + ) + + # Verify requests.get was called with verify=False + mock_get.assert_called_once() + assert mock_get.call_args[1]["verify"] is False + + def test_download_custom_timeout(self, temp_dir, mock_response): + """Test that custom timeout is used. + + Parameters + ---------- + temp_dir : str + Temporary directory path. + mock_response : MagicMock + Mock response object. 
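The tests in this suite pin down the observable contract of download_file_from_link without showing its body: streamed chunks with empty ones skipped, a (connect, read) timeout tuple of (30, timeout), directory creation, no file written on HTTP error, and retry-with-sleep on timeouts. A minimal sketch consistent with those assertions (the real implementation lives in topobench.data.utils.io_utils and may differ in details such as progress reporting):

import os
import time

import requests


def download_sketch(file_link, path_to_save, dataset_name, file_format,
                    verify=True, timeout=60, retries=1):
    """Illustrative stand-in for download_file_from_link, not the real code."""
    os.makedirs(path_to_save, exist_ok=True)  # creates nested directories
    target = os.path.join(path_to_save, f"{dataset_name}.{file_format}")
    for attempt in range(retries):
        try:
            response = requests.get(
                file_link, stream=True, verify=verify, timeout=(30, timeout)
            )
            if response.status_code != 200:
                return  # no file is written on HTTP error
            with open(target, "wb") as fh:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # skip empty keep-alive chunks
                        fh.write(chunk)
            return
        except requests.exceptions.Timeout:
            if attempt == retries - 1:
                raise  # retries exhausted
            time.sleep(1)  # mocked out in the tests above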
+ """ + mock_response.iter_content.return_value = [b"test content"] + + with patch("requests.get", return_value=mock_response) as mock_get: + custom_timeout = 120 # 2 minutes per chunk + download_file_from_link( + file_link="https://github.com/aidos-lab/mantra/releases/download/{version}/2_manifolds.json.gz", + path_to_save=temp_dir, + dataset_name="test_dataset", + file_format="tar.gz", + timeout=custom_timeout, + retries=1, + ) + + # Verify requests.get was called with correct timeout + mock_get.assert_called_once() + assert mock_get.call_args[1]["timeout"] == (30, custom_timeout) \ No newline at end of file diff --git a/test/pipeline/test_pipeline.py b/test/pipeline/test_pipeline.py index 785987159..9b02af884 100644 --- a/test/pipeline/test_pipeline.py +++ b/test/pipeline/test_pipeline.py @@ -1,15 +1,26 @@ """Test pipeline for a particular dataset and model.""" +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + import hydra from test._utils.simplified_pipeline import run +import argparse +class TestPipeline: + """Test pipeline for a particular dataset and model. -DATASET = "graph/MUTAG" # ADD YOUR DATASET HERE -MODELS = ["graph/gcn", "cell/topotune", "simplicial/topotune"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE - + The dataset and models to test are stored as class attributes so that + pytest can still collect this class (pytest skips test classes that + define __init__); they are overridden from the command line when this + file is run directly. + """ -class TestPipeline: - """Test pipeline for a particular dataset and model.""" + dataset = "graph/MUTAG" # ADD YOUR DATASET HERE + models = ["graph/gcn"] # ADD ONE OR SEVERAL MODELS OF YOUR CHOICE HERE def setup_method(self): """Setup method.""" @@ -18,12 +29,12 @@ def setup_method(self): def test_pipeline(self): """Test pipeline.""" with hydra.initialize(config_path="../../configs", job_name="job"): - for MODEL in MODELS: + for MODEL in self.models: cfg = hydra.compose( config_name="run.yaml", overrides=[ f"model={MODEL}", - f"dataset={DATASET}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION + f"dataset={self.dataset}", # IF YOU IMPLEMENT A LARGE DATASET WITH AN OPTION TO USE A SLICE OF IT, ADD BELOW THE CORRESPONDING OPTION "trainer.max_epochs=2", "trainer.min_epochs=1", "trainer.check_val_every_n_epoch=1", @@ -32,4 +43,28 @@ def test_pipeline(self): ], return_hydra_config=True ) - run(cfg) \ No newline at end of file + run(cfg) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Test Pipeline") + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Dataset config to test, e.g. 'graph/MUTAG' or 'graph/a123'", + choices=['graph/MUTAG', 'graph/a123'] # ADD YOUR DATASET HERE + ) + parser.add_argument( + "--models", + type=str, + nargs='+', + required=False, + default=["graph/gcn"], + help="Model(s) to use in the pipeline", + ) + args = parser.parse_args() + + TestPipeline.dataset = args.dataset + TestPipeline.models = args.models if isinstance(args.models, list) else [args.models] + test_pipeline = TestPipeline() + test_pipeline.setup_method() + test_pipeline.test_pipeline() \ No newline at end of file diff --git a/topobench/data/datasets/a123.py b/topobench/data/datasets/a123.py new file mode 100644 index 000000000..c181c94a3 --- /dev/null +++ b/topobench/data/datasets/a123.py @@ -0,0 +1,775 @@ +""" +Dataset class for the Bowen et al. mouse auditory cortex calcium imaging dataset. + +This script downloads and processes the original dataset introduced in: + +[Citation] Bowen et al. 
(2024), "Fractured columnar small-world functional network +organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. +https://doi.org/10.1093/pnasnexus/pgae074 + +We apply the preprocessing and graph-construction steps defined in this module to obtain +a representation of neuronal activity suitable for our experiments. + +Please cite the original paper when using this dataset or any derivatives. +""" + +import os +import os.path as osp +import shutil +from typing import ClassVar + +import networkx as nx +import numpy as np +import pandas as pd +import scipy.io +import torch +from omegaconf import DictConfig +from torch_geometric.data import Data, InMemoryDataset, extract_zip +from torch_geometric.io import fs +from torch_geometric.utils import to_undirected + +from topobench.data.utils import download_file_from_link +from topobench.data.utils.io_utils import collect_mat_files, process_mat +from topobench.data.utils.triangle_classifier import ( + TriangleClassifier as BaseTriangleClassifier, +) + + +class TriangleClassifier(BaseTriangleClassifier): + """A123-specific triangle classifier for auditory cortex data. + + Extends the base TriangleClassifier with domain-specific role classification based on: + - Embedding class: determined by number of common neighbors (core, bridge, isolated) + - Weight class: determined by edge correlation strengths (strong, medium, weak) + + This produces 9 classes combining embedding × weight classes. + + Parameters + ---------- + min_weight : float, optional + Minimum correlation to consider as an edge, by default 0.2. + """ + + def __init__(self, min_weight: float = 0.2): + """Initialize A123 triangle classifier. + + Parameters + ---------- + min_weight : float, optional + Minimum correlation to consider as an edge, by default 0.2 + """ + super().__init__(min_weight=min_weight) + + def _classify_role( + self, G: nx.Graph, nodes: tuple, edge_weights: list + ) -> str: + """Classify role of triangle based on edge weights and embedding. + + Parameters + ---------- + G : nx.Graph + The correlation graph. + nodes : tuple + Three node indices forming the triangle. + edge_weights : list + Three edge weights. + + Returns + ------- + str + Role string in format "{embedding_class}_{weight_class}". + """ + a, b, c = nodes + + # Edge weight class + w_sorted = sorted(edge_weights) + if all(w > 0.5 for w in edge_weights): + weight_class = "strong" + elif w_sorted[0] < 0.3: + weight_class = "weak" + else: + weight_class = "medium" + + # Embedding class: how many other nodes connect to all 3 triangle nodes + common = len( + (set(G.neighbors(a)) & set(G.neighbors(b)) & set(G.neighbors(c))) + - {a, b, c} + ) + + if common >= 3: + embedding_class = "core" + elif common == 0: + embedding_class = "isolated" + else: + embedding_class = "bridge" + + return f"{embedding_class}_{weight_class}" + + def _role_to_label(self, role_str: str) -> int: + """Convert role string to integer label. + + All 9 combinations of embedding class × weight class are supported: + + Embedding classes: core (high common neighbors), bridge, isolated (low common neighbors) + Weight classes: strong (high correlation), medium, weak (low correlation) + + Parameters + ---------- + role_str : str + Role string (e.g., "core_strong"). + + Returns + ------- + int + Label (0-8), mapping all 9 embedding × weight class combinations. 
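A worked example of the role logic above, on a toy graph (values chosen for illustration):

import networkx as nx

clf = TriangleClassifier(min_weight=0.2)
G = nx.Graph()
G.add_weighted_edges_from([(0, 1, 0.8), (1, 2, 0.7), (0, 2, 0.6)])
G.add_weighted_edges_from([(3, 0, 0.5), (3, 1, 0.5), (3, 2, 0.5)])  # node 3 is one common neighbor

role = clf._classify_role(G, (0, 1, 2), [0.8, 0.7, 0.6])
# all three weights > 0.5 -> "strong"; exactly one common neighbor -> "bridge"
assert role == "bridge_strong"
assert clf._role_to_label(role) == 3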
+ """ + roles = { + # Core triangles (many common neighbors) + "core_strong": 0, + "core_medium": 1, + "core_weak": 2, + # Bridge triangles (some common neighbors) + "bridge_strong": 3, + "bridge_medium": 4, + "bridge_weak": 5, + # Isolated triangles (few/no common neighbors) + "isolated_strong": 6, + "isolated_medium": 7, + "isolated_weak": 8, + } + return roles.get(role_str, 8) # Default to isolated_weak if unknown + + +class A123CortexMDataset(InMemoryDataset): + """A1 and A2/3 mouse auditory cortex dataset. + + Loads neural correlation data from mouse auditory cortex regions. Supports + multiple benchmark tasks: + + 1. Graph Classification: Predict frequency bin (0-8) from graph structure + 2. Triangle Classification: Classify topological role of triangles (motifs) + 3. Triangle Common Neighbors: Predict the number of common neighbours of each triangle (0-8) + + Parameters + ---------- + root : str + Root directory where the dataset will be saved. + name : str + Name of the dataset. + parameters : DictConfig + Configuration parameters for the dataset including corr_threshold, + n_bins, min_neurons, and the specific_task selector. + + Attributes + ---------- + URLS : dict + Dictionary containing the URLs for downloading the dataset. + FILE_FORMAT : dict + Dictionary containing the file formats for the dataset. + RAW_FILE_NAMES : dict + Dictionary containing the raw file names for the dataset. + """ + + URLS: ClassVar = { + "Auditory cortex data": "https://gcell.umd.edu/data/Auditory_cortex_data.zip", + } + + FILE_FORMAT: ClassVar = { + "Auditory cortex data": "zip", + } + + RAW_FILE_NAMES: ClassVar = {} + + def __init__( + self, + root: str, + name: str, + parameters: DictConfig, + ) -> None: + self.name = name + self.parameters = parameters + + # defensive parameter access with sensible defaults + try: + self.corr_threshold = float(parameters.get("corr_threshold", 0.2)) + except Exception: + self.corr_threshold = float( + getattr(parameters, "corr_threshold", 0.2) + ) + + try: + self.n_bins = int(parameters.get("n_bins", 9)) + except Exception: + self.n_bins = int(getattr(parameters, "n_bins", 9)) + + try: + self.min_neurons = int(parameters.get("min_neurons", 8)) + except Exception: + self.min_neurons = int(getattr(parameters, "min_neurons", 8)) + + # Task type from parameters (classification, triangle_classification, or triangle_common_neighbors) + try: + self.task_type = str( + parameters.get("specific_task", "classification") + ) + except Exception: + self.task_type = str( + getattr(parameters, "specific_task", "classification") + ) + + self.session_map = {} + super().__init__( + root, + ) + + out = fs.torch_load(self.processed_paths[0]) + assert len(out) == 3 or len(out) == 4 + if len(out) == 3: # Backward compatibility. + data, self.slices, self.sizes = out + data_cls = Data + else: + data, self.slices, self.sizes, data_cls = out + + if not isinstance(data, dict): # Backward compatibility. + self.data = data + else: + self.data = data_cls.from_dict(data) + + # For this dataset we don't assume the internal _data is a torch_geometric Data + # (this dataset exposes helper methods to construct subgraphs on demand). + + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.force_reload={self.force_reload})" + + @property + def raw_dir(self) -> str: + """Path to the raw directory of the dataset. + + Returns + ------- + str + Path to the raw directory. 
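For completeness, a direct-construction sketch for the class above, bypassing Hydra (parameter names follow configs/dataset/graph/a123.yaml; the root path is illustrative):

from omegaconf import OmegaConf

params = OmegaConf.create(
    {
        "corr_threshold": 0.2,
        "min_neurons": 3,
        "specific_task": "classification",
    }
)
dataset = A123CortexMDataset(
    root="datasets/graph/A123CortexM",  # hypothetical location
    name="a123_cortex_m",
    parameters=params,
)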
+ """ + return osp.join(self.root, self.name, "raw") + + @property + def processed_dir(self) -> str: + """Path to the processed directory of the dataset. + + Returns + ------- + str + Path to the processed directory. + """ + return osp.join(self.root, self.name, "processed") + + @property + def raw_file_names(self) -> list[str]: + """Return the raw file names for the dataset. + + Returns + ------- + list[str] + List of raw file names. + """ + return ["Auditory cortex data/"] + + @property + def processed_file_names(self) -> str: + """Return the processed file name for the dataset. + + Returns + ------- + str + Processed file name. + """ + return "data.pt" + + def download(self) -> None: + """Download the dataset from a URL and extract to the raw directory.""" + # Download data from the source + dataset_key = "Auditory cortex data" + self.url = self.URLS[dataset_key] + self.file_format = self.FILE_FORMAT[dataset_key] + + # Use self.name as the downloadable dataset name + download_file_from_link( + file_link=self.url, + path_to_save=self.raw_dir, + dataset_name=self.name, + file_format=self.file_format, + verify=False, + timeout=60, # 60 seconds per chunk read timeout + retries=3, # Retry up to 3 times + ) + + # Extract zip file + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = osp.join(folder, filename) + extract_zip(path, folder) + # Delete zip file + os.unlink(path) + + # Move files from extracted "Auditory cortex data/" directory to raw_dir + downloaded_dir = osp.join(folder, self.name) + if osp.exists(downloaded_dir): + for file in os.listdir(downloaded_dir): + src = osp.join(downloaded_dir, file) + dst = osp.join(folder, file) + if osp.isdir(src): + shutil.copytree(src, dst, dirs_exist_ok=True) + else: + shutil.move(src, dst) + # Delete the extracted top-level directory + shutil.rmtree(downloaded_dir) + self.data_dir = folder + + @staticmethod + def extract_samples(data_dir: str, n_bins: int, min_neurons: int = 8): + """Extract subgraph samples from raw .mat files. + + Parameters + ---------- + data_dir : str + Directory containing the raw .mat files. + n_bins : int + Number of frequency bins to use for binning. + min_neurons : int, optional + Minimum number of neurons required per sample. Defaults to 8. + + Returns + ------- + pd.DataFrame + DataFrame containing extracted samples with columns for + session_file, session_id, layer, bf_bin, neuron_indices, + corr, and noise_corr. 
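The per-layer selection implemented in the body below is a bin-and-filter step; in miniature, with synthetic values and min_neurons=3:

import numpy as np

bfvals = np.array([0, 0, 0, 1, 1])  # best-frequency bin per neuron
scorrs = np.random.rand(5, 5)  # stand-in signal correlation matrix
bin_ids = bfvals.astype(int)

sel = np.where(bin_ids == 0)[0]  # neurons in bin 0 -> [0, 1, 2]
assert len(sel) >= 3  # passes the min_neurons filter
subcorr = scorrs[np.ix_(sel, sel)]  # 3x3 correlation sub-matrix for this sample
# bin 1 holds only 2 neurons, so it would be skipped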
+ """ + mat_files = collect_mat_files(data_dir) + + samples = [] + session_id = 0 + for f in mat_files: + print(f"Processing session {session_id}: {os.path.basename(f)}") + mt = process_mat(scipy.io.loadmat(f)) + for layer in range(1, 6): + scorrs = np.array(mt["selectZCorrInfo"]["SigCorrs"]) + ncorrs = np.array(mt["selectZCorrInfo"]["NoiseCorrsTrial"]) + bfvals = np.array(mt["BFInfo"][layer]["BFval"]).ravel() + if scorrs.size == 0 or bfvals.size == 0: + continue + + bin_ids = bfvals.astype(int) + + for bin_idx in range(n_bins): + sel = np.where(bin_ids == bin_idx)[0] + if len(sel) < min_neurons: + continue + subcorr = scorrs[np.ix_(sel, sel)] + samples.append( + { + "session_file": f, + "session_id": session_id, + "layer": layer, + "bf_bin": int(bin_idx), + "neuron_indices": sel.tolist(), + "corr": subcorr.astype(float), + "noise_corr": ncorrs[np.ix_(sel, sel)].astype( + float + ), + } + ) + session_id += 1 + + samples = pd.DataFrame(samples) + return samples + + def _sample_to_pyg_data( + self, sample: dict, threshold: float = 0.2 + ) -> Data: + """Convert a sample dictionary to a PyTorch Geometric Data object. + + Converts correlation matrices to graph representation with node features + and edges for graph-level classification tasks. + + Parameters + ---------- + sample : dict + Sample dictionary containing 'corr', 'noise_corr', 'session_id', + 'layer', and 'bf_bin' keys. + threshold : float, optional + Correlation threshold for creating edges. Defaults to 0.2. + + Returns + ------- + torch_geometric.data.Data + Data object with node features [mean_corr, std_corr, noise_diag], + edges from thresholded correlation, and label y as integer bf_bin. + """ + corr = np.asarray(sample.get("corr")) + if corr.ndim != 2 or corr.size == 0: + # empty placeholder graph + x = torch.zeros((0, 3), dtype=torch.float) + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, 1), dtype=torch.float) + else: + n = corr.shape[0] + # sanitize + corr = np.nan_to_num(corr) + + mean_corr = corr.mean(axis=1) + std_corr = corr.std(axis=1) + noise_diag = np.zeros(n) + if "noise_corr" in sample and sample["noise_corr"] is not None: + nc = np.asarray(sample["noise_corr"]) + if nc.shape == corr.shape: + noise_diag = np.diag(nc) + + x_np = np.vstack([mean_corr, std_corr, noise_diag]).T + x = torch.tensor(x_np, dtype=torch.float) + + # build edges from thresholded correlation (upper triangle) + adj = (corr >= threshold).astype(int) + iu = np.triu_indices(n, k=1) + sel = np.where(adj[iu] == 1)[0] + if sel.size == 0: + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0, 1), dtype=torch.float) + else: + rows = iu[0][sel] + cols = iu[1][sel] + edge_index_np = np.vstack([rows, cols]) + edge_index = torch.tensor(edge_index_np, dtype=torch.long) + # corr weights for the upper-triangle edges + weights = torch.tensor( + corr[rows, cols].reshape(-1, 1), dtype=torch.float + ) + # make undirected; passing the weights keeps edge_attr + # aligned with the reordered, mirrored edge_index + edge_index, edge_attr = to_undirected(edge_index, weights) + + y = torch.tensor([int(sample.get("bf_bin", -1))], dtype=torch.long) + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + # attach metadata + data.session_id = int(sample.get("session_id", -1)) + data.layer = int(sample.get("layer", -1)) + return data + + def _extract_triangles_from_graphs(self) -> list: + """Extract raw triangle data from all graphs with NetworkX 
representations. + + Returns a list of dicts, each containing graph metadata and triangle info. + + Returns + ------- + list of dict + Each dict has keys: + - 'graph_idx': index of source graph + - 'tri': triangle dict from classifier (with nodes, edge_weights, role, label) + - 'G': NetworkX graph object (for structural queries) + - 'num_nodes': number of nodes in graph + """ + import time + + classifier = TriangleClassifier(min_weight=self.corr_threshold) + raw_triangles = [] + + print("[A123] Starting triangle extraction from graphs...") + + num_graphs = len(self) + start_time = time.time() + + for graph_idx in range(num_graphs): + if graph_idx % 10 == 0 and graph_idx > 0: + elapsed = time.time() - start_time + avg_time = elapsed / graph_idx + remaining = avg_time * (num_graphs - graph_idx) + print( + f"[A123] Processed {graph_idx}/{num_graphs} graphs " + f"({int(elapsed)}s elapsed, ~{int(remaining)}s remaining)..." + ) + + # Get individual data object using get() method + data = self.get(graph_idx) + + # Skip graphs with no edges + if data.edge_index.shape[1] == 0: + continue + + # Build NetworkX graph for structural queries + num_nodes = ( + data.x.shape[0] + if hasattr(data, "x") and data.x is not None + else 0 + ) + + try: + # Build NetworkX graph once (reuse in both enumeration and classification) + G = nx.Graph() + G.add_nodes_from(range(num_nodes)) + for i in range(data.edge_index.shape[1]): + u = int(data.edge_index[0, i].item()) + v = int(data.edge_index[1, i].item()) + w = ( + float(data.edge_attr[i].item()) + if data.edge_attr is not None + else 1.0 + ) + G.add_edge(u, v, weight=w) + + # Enumerate triangles and classify them using separate methods + triangles = classifier.enumerate_triangles(G) + triangle_data = classifier.classify_and_weight_triangles( + triangles, G + ) + + except Exception as e: + print( + f"[A123] Warning: Could not extract triangles for graph {graph_idx}: {e}" + ) + import traceback + + traceback.print_exc() + continue + + # Store raw triangle data with graph context + for tri in triangle_data: + raw_triangles.append( # noqa: PERF401 (appending dict, not extending) + { + "graph_idx": graph_idx, + "tri": tri, + "G": G, + "num_nodes": num_nodes, + } + ) + + elapsed = time.time() - start_time + print( + f"[A123] Triangle extraction completed in {int(elapsed)}s, found {len(raw_triangles)} triangles" + ) + return raw_triangles + + def create_triangle_classification_task(self) -> list: + """Create triangle-level classification dataset from graph-level data. + + Extracts all triangles from each graph and creates a new dataset where + each sample is a triangle classified by its topological role. Features + are purely topological (edge weights only) - independent of original + node properties or frequency information. 
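Schematically, each sample emitted by this method is a tiny PyG Data object of the following shape (values illustrative):

import torch
from torch_geometric.data import Data

tri_sample = Data(
    x=torch.tensor([[0.8, 0.7, 0.6]]),  # (1, 3) edge weights only
    y=torch.tensor(3, dtype=torch.long),  # role label in 0-8, e.g. bridge_strong
    nodes=torch.tensor([0, 1, 2]),
    role="bridge_strong",
    graph_idx=0,
)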
+ + Uses 9 classes based on all combinations of embedding class and weight class: + - 0: core_strong (high common neighbors, strong correlation) + - 1: core_medium (high common neighbors, medium correlation) + - 2: core_weak (high common neighbors, weak correlation) + - 3: bridge_strong (some common neighbors, strong correlation) + - 4: bridge_medium (some common neighbors, medium correlation) + - 5: bridge_weak (some common neighbors, weak correlation) + - 6: isolated_strong (few common neighbors, strong correlation) + - 7: isolated_medium (few common neighbors, medium correlation) + - 8: isolated_weak (few common neighbors, weak correlation) + + Returns + ------- + list of torch_geometric.data.Data + Triangle-level samples with 3D edge weight features and role labels (0-8). + """ + raw_triangles = self._extract_triangles_from_graphs() + triangle_data_list = [] + + print("[A123] Creating triangle classification task...") + + for item in raw_triangles: + tri = item["tri"] + graph_idx = item["graph_idx"] + + # Topological features only: edge weights + tri_edge_weights = torch.tensor( + tri["edge_weights"], dtype=torch.float32 + ) # (3,) + + # Use the role label (0-8) for all 9 triangle topological classes + label = tri["label"] # Now 0-8 from _role_to_label() + + # Create data object for this triangle + tri_data = Data( + x=tri_edge_weights.unsqueeze(0), # (1, 3) - edge weights only + y=torch.tensor(label, dtype=torch.long), + nodes=torch.tensor(tri["nodes"], dtype=torch.long), + role=tri["role"], + graph_idx=graph_idx, + ) + + triangle_data_list.append(tri_data) + + print(f"[A123] Created {len(triangle_data_list)} triangle samples") + return triangle_data_list + + def create_triangle_common_neighbors_task(self) -> list: + """Create triangle-level dataset where label is the number of common neighbours. + + For each triangle (a,b,c) we compute: + - feature: the degrees of the three nodes (structural, no weights) + - label: number of nodes that are neighbours to all three (common neighbours) + Classes: 0-7 neighbors map to classes 0-7, 8+ neighbors map to class 8 + + Returns + ------- + list of torch_geometric.data.Data + Each Data contains x (1,3) degrees, y (scalar) common-neighbour count (0-8), + nodes (3,), role (str) optionally, and graph_idx metadata. + """ + raw_triangles = self._extract_triangles_from_graphs() + triangle_data_list = [] + + print("[A123] Creating triangle common-neighbors task...") + + for item in raw_triangles: + tri = item["tri"] + G = item["G"] + graph_idx = item["graph_idx"] + + a, b, c = tri["nodes"] + + # Compute common neighbours (exclude triangle nodes) + common = ( + set(G.neighbors(a)) & set(G.neighbors(b)) & set(G.neighbors(c)) + ) - {a, b, c} + num_common = len(common) + + # Cap at 8: 0-7 neighbors are their own class, 8+ neighbors are class 8 + label = min(num_common, 8) + + # Node degree features (structural) + deg_a = G.degree(a) + deg_b = G.degree(b) + deg_c = G.degree(c) + tri_feats = torch.tensor( + [deg_a, deg_b, deg_c], dtype=torch.float32 + ) + + tri_data = Data( + x=tri_feats.unsqueeze(0), # (1,3) + y=torch.tensor([int(label)], dtype=torch.long), + nodes=torch.tensor(tri["nodes"], dtype=torch.long), + role=tri.get("role", ""), + graph_idx=graph_idx, + ) + + triangle_data_list.append(tri_data) + + print(f"[A123] Created {len(triangle_data_list)} triangle CN samples") + return triangle_data_list + + def process(self) -> None: + """Generate raw files into collated PyG dataset and save to disk. 
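When a triangle task is selected, process() writes a second processed file next to data.pt. Reloading it follows the same tuple convention used for the main file; a sketch (the path depends on your processed_dir):

from torch_geometric.io import fs

out = fs.torch_load("processed/data_triangles.pt")  # hypothetical path
data_dict, slices, sizes, data_cls = out
data = data_cls.from_dict(data_dict)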
+
+    def create_triangle_common_neighbors_task(self) -> list:
+        """Create triangle-level dataset where the label is the number of common neighbours.
+
+        For each triangle (a, b, c) we compute:
+        - feature: the degrees of the three nodes (structural, no weights)
+        - label: the number of nodes adjacent to all three (common neighbours)
+        Classes: 0-7 common neighbours map to classes 0-7; 8+ map to class 8.
+
+        Returns
+        -------
+        list of torch_geometric.data.Data
+            Each Data contains x (1,3) degrees, y (scalar) common-neighbour count (0-8),
+            nodes (3,), role (str, optional), and graph_idx metadata.
+        """
+        raw_triangles = self._extract_triangles_from_graphs()
+        triangle_data_list = []
+
+        print("[A123] Creating triangle common-neighbors task...")
+
+        for item in raw_triangles:
+            tri = item["tri"]
+            G = item["G"]
+            graph_idx = item["graph_idx"]
+
+            a, b, c = tri["nodes"]
+
+            # Compute common neighbours (exclude the triangle's own nodes)
+            common = (
+                set(G.neighbors(a)) & set(G.neighbors(b)) & set(G.neighbors(c))
+            ) - {a, b, c}
+            num_common = len(common)
+
+            # Cap at 8: 0-7 neighbours are their own class, 8+ neighbours are class 8
+            label = min(num_common, 8)
+
+            # Node degree features (structural)
+            deg_a = G.degree(a)
+            deg_b = G.degree(b)
+            deg_c = G.degree(c)
+            tri_feats = torch.tensor(
+                [deg_a, deg_b, deg_c], dtype=torch.float32
+            )
+
+            tri_data = Data(
+                x=tri_feats.unsqueeze(0),  # (1,3)
+                y=torch.tensor([int(label)], dtype=torch.long),
+                nodes=torch.tensor(tri["nodes"], dtype=torch.long),
+                role=tri.get("role", ""),
+                graph_idx=graph_idx,
+            )
+
+            triangle_data_list.append(tri_data)
+
+        print(f"[A123] Created {len(triangle_data_list)} triangle CN samples")
+        return triangle_data_list
+
+    def process(self) -> None:
+        """Process raw files into a collated PyG dataset and save it to disk.
+
+        This implementation mirrors other datasets in the repo: it calls the
+        static helper `extract_samples()` to enumerate subgraphs, converts each
+        to a `torch_geometric.data.Data` object via `_sample_to_pyg_data()`,
+        optionally computes/attaches topology vectors, then collates and saves.
+
+        If a triangle task is selected via `specific_task`, also creates and
+        saves the corresponding triangle-level dataset.
+        """
+        data_dir = self.raw_dir
+
+        print(f"[A123] Processing dataset from: {data_dir}")
+        print(f"[A123] Files in raw_dir: {os.listdir(data_dir)}")
+
+        # extract sample descriptions
+        print("[A123] Starting extract_samples()...")
+        samples = A123CortexMDataset.extract_samples(
+            data_dir, self.n_bins, self.min_neurons
+        )
+
+        print(f"[A123] Extracted {len(samples)} samples")
+
+        data_list = []
+        skipped_count = 0
+        for idx, (_, s) in enumerate(samples.iterrows()):
+            if idx % 100 == 0:
+                print(
+                    f"[A123] Converting sample {idx}/{len(samples)} to PyG Data..."
+                )
+            d = self._sample_to_pyg_data(s, threshold=self.corr_threshold)
+            # Filter out empty graphs (graphs with no edges)
+            if d.edge_index is not None and d.edge_index.numel() > 0:
+                data_list.append(d)
+            else:
+                skipped_count += 1
+
+        # collate and save processed dataset
+        print(
+            f"[A123] Collating {len(data_list)} samples (removed {skipped_count} empty graphs)..."
+        )
+        self.data, self.slices = self.collate(data_list)
+        self._data_list = None
+        print(f"[A123] Saving processed data to {self.processed_paths[0]}...")
+        fs.torch_save(
+            (self._data.to_dict(), self.slices, {}, self._data.__class__),
+            self.processed_paths[0],
+        )
+
+        # If the triangle task is selected, create and save the triangle classification dataset
+        specific_task = self.parameters.get("specific_task", "classification")
+        if specific_task == "triangle_classification":
+            print(
+                "[A123] Triangle task enabled. Creating triangle classification dataset..."
+            )
+            triangle_data = self.create_triangle_classification_task()
+
+            # Save triangle dataset to a separate file
+            triangle_processed_path = self.processed_paths[0].replace(
+                "data.pt", "data_triangles.pt"
+            )
+            print(f"[A123] Collating {len(triangle_data)} triangle samples...")
+            triangle_collated, triangle_slices = self.collate(triangle_data)
+            print(
+                f"[A123] Saving triangle dataset to {triangle_processed_path}..."
+            )
+            fs.torch_save(
+                (
+                    triangle_collated.to_dict(),
+                    triangle_slices,
+                    {},
+                    triangle_collated.__class__,
+                ),
+                triangle_processed_path,
+            )
+            print("[A123] Triangle task dataset saved!")
+
+        # If the triangle common-neighbours task is selected, create and save it
+        if specific_task == "triangle_common_neighbors":
+            print(
+                "[A123] Triangle common-neighbours task enabled. Creating dataset..."
+            )
+            triangle_cn_data = self.create_triangle_common_neighbors_task()
+
+            triangle_cn_processed_path = self.processed_paths[0].replace(
+                "data.pt", "data_triangles_common_neighbors.pt"
+            )
+            print(
+                f"[A123] Collating {len(triangle_cn_data)} triangle CN samples..."
+            )
+            triangle_cn_collated, triangle_cn_slices = self.collate(
+                triangle_cn_data
+            )
+            print(
+                f"[A123] Saving triangle CN dataset to {triangle_cn_processed_path}..."
+ ) + fs.torch_save( + ( + triangle_cn_collated.to_dict(), + triangle_cn_slices, + {}, + triangle_cn_collated.__class__, + ), + triangle_cn_processed_path, + ) + print("[A123] Triangle CN dataset saved!") + + print("[A123] Processing complete!") diff --git a/topobench/data/loaders/graph/a123_loader.py b/topobench/data/loaders/graph/a123_loader.py new file mode 100644 index 000000000..65e2d9ab3 --- /dev/null +++ b/topobench/data/loaders/graph/a123_loader.py @@ -0,0 +1,163 @@ +""" +Data loader for the Bowen et al. mouse auditory cortex calcium imaging dataset. + +This script downloads and processes the original dataset introduced in: + +[Citation] Bowen et al. (2024), "Fractured columnar small-world functional network +organization in volumes of L2/3 of mouse auditory cortex," PNAS Nexus, 3(2): pgae074. +https://doi.org/10.1093/pnasnexus/pgae074 + +We apply the preprocessing and graph-construction steps defined in this module to obtain +a representation of neuronal activity suitable for our experiments. + +Please cite the original paper when using this dataset or any derivatives. +""" + +import os.path as osp + +import torch +from omegaconf import DictConfig +from torch_geometric.io import fs + +from topobench.data.datasets.a123 import A123CortexMDataset +from topobench.data.loaders.base import AbstractLoader + + +class A123DatasetLoader(AbstractLoader): + """Loader for A123 mouse auditory cortex dataset. + + Implements the AbstractLoader interface: accepts a DictConfig `parameters` + and implements `load_dataset()` which returns a dataset object. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters for the dataset. + **overrides + Additional keyword arguments to override parameters. + """ + + def __init__(self, parameters: DictConfig, **overrides): + """Initialize the A123 dataset loader. + + Parameters + ---------- + parameters : DictConfig + Configuration parameters for the dataset. + **overrides + Additional keyword arguments to override parameters. + """ + # Initialize AbstractLoader (sets self.parameters and self.root_data_dir) + super().__init__(parameters) + + # hyperparameters can come from the DictConfig or be passed as overrides + params = parameters if parameters is not None else {} + + def _get(k, default): + """Get parameter value from DictConfig or overrides. + + Parameters + ---------- + k : str + Parameter key. + default : Any + Default value if key not found. + + Returns + ------- + Any + Parameter value from DictConfig or overrides, or default. + """ + try: + return params.get(k, overrides.get(k, default)) + except Exception: + # DictConfig may use attribute access + return getattr(params, k, overrides.get(k, default)) + + self.batch_size = int(_get("batch_size", 32)) + # dataset will be created when load_dataset() is called + self.dataset = None + + def load_dataset( + self, + ) -> torch.utils.data.Dataset: + """Instantiate and return the underlying dataset. + + Returns a `A123CortexMDataset` instance constructed from the loader's + parameters and root data directory. + + Returns + ------- + torch.utils.data.Dataset + A123CortexMDataset instance or triangle dataset. 
+ """ + # determine dataset name from parameters, fallback to expected id + name = self.parameters.data_name + task_type = str( + getattr(self.parameters, "specific_task", "classification") + ) + + # root path for dataset: use the parent of root_data_dir since the dataset + # constructs its own subdirectory based on name + root = str(self.root_data_dir.parent) + + # Construct dataset; A123CortexMDataset expects (root, name, parameters) + self.dataset = A123CortexMDataset( + root=root, name=name, parameters=self.parameters + ) + + # If triangle task requested, load triangle dataset instead + if task_type == "triangle_classification": + # Load triangle classification dataset + processed_dir = self.dataset.processed_dir + triangle_data_path = osp.join(processed_dir, "data_triangles.pt") + + if osp.exists(triangle_data_path): + # Load triangle data + out = fs.torch_load(triangle_data_path) + assert len(out) == 4 + data, slices, sizes, data_cls = out + + if not isinstance(data, dict): + self.dataset.data = data + else: + self.dataset.data = data_cls.from_dict(data) + + self.dataset.slices = slices + print( + "[A123 Loader] Loaded triangle classification task dataset" + ) + else: + print( + f"[A123 Loader] Warning: Triangle dataset not found at {triangle_data_path}. " + f"Ensure triangle_task.enabled=true in config and dataset has been processed." + ) + + # Triangle common-neighbours task + if task_type == "triangle_common_neighbors": + processed_dir = self.dataset.processed_dir + triangle_cn_path = osp.join( + processed_dir, "data_triangles_common_neighbors.pt" + ) + + if osp.exists(triangle_cn_path): + out = fs.torch_load(triangle_cn_path) + assert len(out) == 4 + data, slices, sizes, data_cls = out + + if not isinstance(data, dict): + self.dataset.data = data + else: + self.dataset.data = data_cls.from_dict(data) + + self.dataset.slices = slices + print( + "[A123 Loader] Loaded triangle common-neighbours task dataset" + ) + else: + print( + f"[A123 Loader] Warning: Triangle CN dataset not found at {triangle_cn_path}. " + f"Ensure triangle_common_task.enabled=true in config and dataset has been processed." + ) + + return self.dataset diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py index 372db85e6..a20913e8f 100644 --- a/topobench/data/utils/io_utils.py +++ b/topobench/data/utils/io_utils.py @@ -1,8 +1,11 @@ """Data IO utilities.""" +import glob import json +import os import os.path as osp import pickle +import time from urllib.parse import parse_qs, urlparse import numpy as np @@ -85,10 +88,19 @@ def download_file_from_drive( def download_file_from_link( - file_link, path_to_save, dataset_name, file_format="tar.gz" + file_link, + path_to_save, + dataset_name, + file_format="tar.gz", + verify=True, + timeout=None, + retries=3, ): """Download a file from a link and saves it to the specified path. + Uses streaming with chunked download and includes retry logic for + resilience against network interruptions. + Parameters ---------- file_link : str @@ -99,20 +111,171 @@ def download_file_from_link( The name of the dataset. file_format : str, optional The format of the downloaded file. Defaults to "tar.gz". + verify : bool, optional + Whether to verify SSL certificates. Defaults to True. + timeout : float, optional + Timeout in seconds per chunk read (not for entire download). For very slow + servers, increase this value. Default: 60 seconds per chunk. + retries : int, optional + Number of retry attempts if download fails. Defaults to 3. 
diff --git a/topobench/data/utils/io_utils.py b/topobench/data/utils/io_utils.py
index 372db85e6..a20913e8f 100644
--- a/topobench/data/utils/io_utils.py
+++ b/topobench/data/utils/io_utils.py
@@ -1,8 +1,11 @@
 """Data IO utilities."""
 
+import glob
 import json
+import os
 import os.path as osp
 import pickle
+import time
 from urllib.parse import parse_qs, urlparse
 
 import numpy as np
@@ -85,10 +88,19 @@ def download_file_from_drive(
 
 
 def download_file_from_link(
-    file_link, path_to_save, dataset_name, file_format="tar.gz"
+    file_link,
+    path_to_save,
+    dataset_name,
+    file_format="tar.gz",
+    verify=True,
+    timeout=None,
+    retries=3,
 ):
     """Download a file from a link and saves it to the specified path.
 
+    Uses streaming with chunked download and includes retry logic for
+    resilience against network interruptions.
+
     Parameters
     ----------
     file_link : str
         The link of the file to download.
     path_to_save : str
         The path where the file will be saved.
     dataset_name : str
         The name of the dataset.
     file_format : str, optional
         The format of the downloaded file. Defaults to "tar.gz".
+    verify : bool, optional
+        Whether to verify SSL certificates. Defaults to True.
+    timeout : float, optional
+        Timeout in seconds per chunk read (not for the entire download). For very
+        slow servers, increase this value. Default: 60 seconds per chunk.
+    retries : int, optional
+        Number of retry attempts if the download fails. Defaults to 3.
+
+    Notes
+    -----
+    This function downloads files in 5MB chunks for memory efficiency. Progress is
+    reported every 10MB. Timeouts apply per chunk, not to the entire download,
+    making it suitable for very large files and slow connections.
+
+    If a download fails, it retries with a linearly increasing back-off (5s, 10s, 15s).
+
+    Examples
+    --------
+    Basic download:
+
+    >>> from topobench.data.utils import download_file_from_link
+    >>> download_file_from_link(
+    ...     file_link="https://example.com/dataset.tar.gz",
+    ...     path_to_save="./data/",
+    ...     dataset_name="my_dataset"
+    ... )
+
+    Download with custom timeout for slow servers:
+
+    >>> download_file_from_link(
+    ...     file_link="https://slow-server.com/dataset.zip",
+    ...     path_to_save="./data/",
+    ...     dataset_name="my_dataset",
+    ...     file_format="zip",
+    ...     timeout=300  # 5 minutes per chunk
+    ... )
+
+    Download with increased retries for unreliable connections:
+
+    >>> download_file_from_link(
+    ...     file_link="https://example.com/dataset.tar.gz",
+    ...     path_to_save="./data/",
+    ...     dataset_name="my_dataset",
+    ...     retries=5  # Try up to 5 times
+    ... )
 
     Raises
     ------
-    None
+    Exception
+        If the download fails after all retry attempts.
     """
-    response = requests.get(file_link)
-
+    # Ensure output directory exists
+    os.makedirs(path_to_save, exist_ok=True)
     output_path = f"{path_to_save}/{dataset_name}.{file_format}"
-    if response.status_code == 200:
-        with open(output_path, "wb") as f:
-            f.write(response.content)
-        print("Download complete.")
-    else:
-        print("Failed to download the file.")
+
+    # Default timeout: 60 seconds per chunk read (for very slow servers)
+    if timeout is None:
+        timeout = 60
+
+    for attempt in range(retries):
+        try:
+            print(
+                f"[Download] Starting download from: {file_link} (attempt {attempt + 1}/{retries})"
+            )
+
+            # Use tuple (connect_timeout, read_timeout) for proper streaming
+            response = requests.get(
+                file_link,
+                verify=verify,
+                stream=True,  # Force streaming for chunked download
+                timeout=(
+                    30,
+                    timeout,
+                ),  # (connect timeout, read timeout per chunk)
+            )
+
+            if response.status_code != 200:
+                print(
+                    f"[Download] Failed to download the file. HTTP {response.status_code}"
+                )
+                return
+
+            # Streaming download with progress reporting
+            total_size = int(response.headers.get("content-length", 0))
+            downloaded = 0
+            start_time = time.time()
+
+            if total_size > 0:
+                print(
+                    f"[Download] Total file size: {total_size / (1024**3):.2f} GB"
+                )
+            else:
+                print("[Download] Total file size: unknown")
+
+            # Stream download in chunks
+            chunk_size = 5 * 1024 * 1024  # 5MB chunks for faster throughput
+            progress_interval = (
+                10 * 1024 * 1024
+            )  # Report progress every 10MB (for slow connections)
+            last_reported = 0
+
+            with open(output_path, "wb") as f:
+                for chunk in response.iter_content(
+                    chunk_size=chunk_size, decode_unicode=False
+                ):
+                    if chunk:
+                        f.write(chunk)
+                        f.flush()  # Ensure data is written to disk
+                        downloaded += len(chunk)
+
+                        # Print progress every 10MB
+                        if (
+                            total_size > 0
+                            and (downloaded - last_reported)
+                            >= progress_interval
+                        ):
+                            percent = (downloaded / total_size) * 100
+                            remaining = total_size - downloaded
+                            elapsed_time = time.time() - start_time
+                            speed_mbps = (downloaded / (1024**2)) / (
+                                elapsed_time + 0.001
+                            )
+
+                            # Calculate ETA
+                            if speed_mbps > 0:
+                                eta_seconds = (
+                                    remaining / (1024**2) / speed_mbps
+                                )
+                                eta_hours = eta_seconds / 3600
+                                eta_minutes = (eta_seconds % 3600) / 60
+                                eta_str = (
+                                    f"{eta_hours:.0f}h {eta_minutes:.0f}m"
+                                )
+                            else:
+                                eta_str = "calculating..."
+
+                            print(
+                                f"[Download] {downloaded / (1024**3):.2f} / {total_size / (1024**3):.2f} GB ({percent:.1f}%) | Speed: {speed_mbps:.2f} MB/s | ETA: {eta_str}"
+                            )
+                            last_reported = downloaded
+
+            print(f"[Download] Download complete! Saved to: {output_path}")
+            break
+
+        except Exception as e:  # also covers requests.exceptions.Timeout / ConnectionError
+            print(
+                f"[Download] Download failed with error: {type(e).__name__}: {str(e)}"
+            )
+            if attempt < retries - 1:
+                wait_time = 5 * (attempt + 1)  # Linear back-off: 5s, 10s, 15s
+                print(
+                    f"[Download] Retrying in {wait_time} seconds... (attempt {attempt + 2}/{retries})"
+                )
+                time.sleep(wait_time)
+            else:
+                print(
+                    f"[Download] Failed after {retries} attempts. Please check your connection and try again."
+                )
+                raise
 
 
 def read_ndim_manifolds(
@@ -580,3 +743,113 @@ def load_hypergraph_content_dataset(data_dir, data_name):
     print("Final num_class", data.num_class)
 
     return data, data_dir
+
+
+def collect_mat_files(data_dir: str) -> list:
+    """Collect all .mat files from a directory recursively.
+
+    Excludes files containing "diffxy" in their names.
+
+    Parameters
+    ----------
+    data_dir : str
+        Root directory to search for .mat files.
+
+    Returns
+    -------
+    list
+        Sorted list of .mat file paths.
+    """
+    patterns = [os.path.join(data_dir, "**", "*.mat")]
+    files = []
+    for p in patterns:
+        files.extend(glob.glob(p, recursive=True))
+    files = [f for f in files if "diffxy" not in f]
+    files.sort()
+    return files
+
+
+def mat_cell_to_dict(mt) -> dict:
+    """Convert a MATLAB cell array to a dictionary.
+
+    Parameters
+    ----------
+    mt : np.ndarray
+        MATLAB cell array (structured array).
+
+    Returns
+    -------
+    dict
+        Dictionary with keys from cell array field names and squeezed values.
+    """
+    clean_data = {}
+    keys = mt.dtype.names
+    for key_idx, key in enumerate(keys):
+        clean_data[key] = (
+            np.squeeze(mt[key_idx])
+            if isinstance(mt[key_idx], np.ndarray)
+            else mt[key_idx]
+        )
+    return clean_data
+
+
+def planewise_mat_cell_to_dict(mt) -> dict:
+    """Convert a plane-wise MATLAB cell array to a nested dictionary.
+
+    Parameters
+    ----------
+    mt : np.ndarray
+        MATLAB cell array with plane dimension.
+
+    Returns
+    -------
+    dict
+        Nested dictionary with plane IDs as keys.
+    """
+    clean_data = {}
+    for plane_id in range(len(mt[0])):
+        keys = mt[0, plane_id].dtype.names
+        clean_data[plane_id] = {}
+        for key_idx, key in enumerate(keys):
+            clean_data[plane_id][key] = (
+                np.squeeze(mt[0, plane_id][key_idx])
+                if isinstance(mt[0, plane_id][key_idx], np.ndarray)
+                else mt[0, plane_id][key_idx]
+            )
+    return clean_data
+
+
+def process_mat(mat_data) -> dict:
+    """Convert a loaded MATLAB data structure into an organized dictionary.
+
+    Converts MATLAB cell arrays for BFInfo, CellInfo, CorrInfo, and other
+    experimental metadata into nested Python dictionaries.
+
+    Parameters
+    ----------
+    mat_data : dict
+        Dictionary loaded from a MATLAB .mat file via scipy.io.loadmat.
+
+    Returns
+    -------
+    dict
+        Processed data structure with organized BFInfo, CellInfo, CorrInfo,
+        coordinate arrays, and experimental variables.
+    """
+    mt = {}
+    mt["BFInfo"] = planewise_mat_cell_to_dict(mat_data["BFinfo"])
+    mt["CellInfo"] = planewise_mat_cell_to_dict(mat_data["CellInfo"])
+    mt["CorrInfo"] = planewise_mat_cell_to_dict(mat_data["CorrInfo"])
+    mt["allZCorrInfo"] = mat_cell_to_dict(mat_data["allZCorrInfo"][0, 0])
+
+    for cord_key in ["allxc", "allyc", "allzc", "zDFF"]:
+        mt[cord_key] = {}
+        for p in range(mat_data[cord_key].shape[0]):
+            mt[cord_key][p] = mat_data[cord_key][p, 0]
+
+    mt["exptVars"] = mat_cell_to_dict(mat_data["exptVars"][0, 0])
+    mt["selectZCorrInfo"] = mat_cell_to_dict(mat_data["selectZCorrInfo"][0, 0])
+    mt["stimInfo"] = planewise_mat_cell_to_dict(mat_data["stimInfo"])
+    mt["zStuff"] = planewise_mat_cell_to_dict(mat_data["zStuff"])
+
+    return mt
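A hedged usage sketch for the .mat helpers above; the path is illustrative, and `process_mat` assumes the Bowen et al. file layout (`BFinfo`, `CellInfo`, `CorrInfo`, ...):

```python
import scipy.io as sio

from topobench.data.utils.io_utils import collect_mat_files, process_mat

files = collect_mat_files("./data/a123_cortex_m/raw")  # assumed local path
if files:
    raw = sio.loadmat(files[0])    # raw MATLAB structs / cell arrays
    session = process_mat(raw)     # nested Python dicts of numpy arrays
    print(sorted(session.keys()))  # BFInfo, CellInfo, CorrInfo, ...
```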
+ """ + mt = {} + mt["BFInfo"] = planewise_mat_cell_to_dict(mat_data["BFinfo"]) + mt["CellInfo"] = planewise_mat_cell_to_dict(mat_data["CellInfo"]) + mt["CorrInfo"] = planewise_mat_cell_to_dict(mat_data["CorrInfo"]) + mt["allZCorrInfo"] = mat_cell_to_dict(mat_data["allZCorrInfo"][0, 0]) + + for cord_key in ["allxc", "allyc", "allzc", "zDFF"]: + mt[cord_key] = {} + for p in range(mat_data[cord_key].shape[0]): + mt[cord_key][p] = mat_data[cord_key][p, 0] + + mt["exptVars"] = mat_cell_to_dict(mat_data["exptVars"][0, 0]) + mt["selectZCorrInfo"] = mat_cell_to_dict(mat_data["selectZCorrInfo"][0, 0]) + mt["stimInfo"] = planewise_mat_cell_to_dict(mat_data["stimInfo"]) + mt["zStuff"] = planewise_mat_cell_to_dict(mat_data["zStuff"]) + + return mt diff --git a/topobench/data/utils/triangle_classifier.py b/topobench/data/utils/triangle_classifier.py new file mode 100644 index 000000000..eec5e2616 --- /dev/null +++ b/topobench/data/utils/triangle_classifier.py @@ -0,0 +1,196 @@ +"""Generic triangle classification utilities for graph analysis. + +This module provides a base class for extracting and classifying triangles +in weighted graphs using efficient algorithms. +""" + +import networkx as nx +import torch + + +class TriangleClassifier: + """Base class for extracting and classifying triangles in graphs. + + Provides generic triangle enumeration and edge weight extraction. + Subclasses should override _classify_role() for domain-specific role definitions. + + Parameters + ---------- + min_weight : float, optional + Minimum edge weight to consider, by default 0.2. + """ + + def __init__(self, min_weight: float = 0.2): + """Initialize triangle classifier. + + Parameters + ---------- + min_weight : float, optional + Minimum edge weight to consider as valid edge, by default 0.2 + """ + self.min_weight = min_weight + + def enumerate_triangles(self, G: nx.Graph) -> list: + """Enumerate all triangles in a graph using efficient O(n^3) enumeration. + + Parameters + ---------- + G : nx.Graph + NetworkX graph object with edges and optional weights. + + Returns + ------- + list of tuple + Each tuple is (a, b, c) representing a triangle. + """ + triangles = [] + nodes = list(G.nodes()) + for i, a in enumerate(nodes): + neighbors_a = set(G.neighbors(a)) + for j, b in enumerate(nodes[i + 1 :], start=i + 1): + if b not in neighbors_a: + continue + neighbors_b = set(G.neighbors(b)) + for c in nodes[j + 1 :]: + if c in neighbors_a and c in neighbors_b: + # Found triangle (a,b,c) + triangles.append((a, b, c)) # noqa: PERF401 (conditional append, not extend) + + return triangles + + def classify_and_weight_triangles( + self, triangles: list, G: nx.Graph + ) -> list: + """Classify triangles and add edge weights and role information. + + Parameters + ---------- + triangles : list of tuple + List of triangles, each as (a, b, c) node indices. + G : nx.Graph + NetworkX graph with edge weights and adjacency information. + + Returns + ------- + list of dict + Each dict contains {'nodes': (a,b,c), 'edge_weights': [w1,w2,w3], 'role': str, 'label': int}. 
+ """ + triangle_data = [] + for nodes in triangles: + a, b, c = nodes + + # Get edge weights + w_ab = G[a][b].get("weight", self.min_weight) + w_bc = G[b][c].get("weight", self.min_weight) + w_ac = G[a][c].get("weight", self.min_weight) + edge_weights_tri = [w_ab, w_bc, w_ac] + + # Classify role (subclasses override _classify_role for domain-specific logic) + role = self._classify_role(G, nodes, edge_weights_tri) + + triangle_data.append( + { + "nodes": nodes, + "edge_weights": edge_weights_tri, + "role": role, + "label": self._role_to_label(role), + } + ) + + return triangle_data + + def extract_triangles( + self, + edge_index: torch.Tensor, + edge_weights: torch.Tensor, + num_nodes: int, + ) -> list: + """Extract all triangles from graph (convenience method). + + Combines enumerate_triangles and classify_and_weight_triangles. + + Parameters + ---------- + edge_index : torch.Tensor + Edge connectivity, shape (2, num_edges). + edge_weights : torch.Tensor + Edge weights, shape (num_edges,). + num_nodes : int + Number of nodes. + + Returns + ------- + list of dict + Each dict contains {'nodes': (a,b,c), 'edge_weights': [w1,w2,w3], 'role': str, 'label': int}. + """ + # Build networkx graph + G = nx.Graph() + G.add_nodes_from(range(num_nodes)) + + for i in range(edge_index.shape[1]): + u = edge_index[0, i].item() + v = edge_index[1, i].item() + w = edge_weights[i].item() + G.add_edge(u, v, weight=w) + + # Enumerate and classify triangles + triangles = self.enumerate_triangles(G) + triangle_data = self.classify_and_weight_triangles(triangles, G) + + return triangle_data + + def _classify_role( + self, G: nx.Graph, nodes: tuple, edge_weights: list + ) -> str: + """Classify role of triangle based on edge weights and embedding. + + This method should be overridden in subclasses to provide domain-specific + role classification logic. + + Parameters + ---------- + G : nx.Graph + The graph. + nodes : tuple + Three node indices forming the triangle. + edge_weights : list + Three edge weights. + + Returns + ------- + str + Role string describing the triangle's role. + + Raises + ------ + NotImplementedError + If not overridden in subclass. + """ + raise NotImplementedError( + "Subclasses must override _classify_role() to define domain-specific role classification" + ) + + def _role_to_label(self, role_str: str) -> int: + """Convert role string to integer label. + + This method should be overridden in subclasses to define the mapping + from role strings to numeric labels. + + Parameters + ---------- + role_str : str + Role string (e.g., "core_strong"). + + Returns + ------- + int + Numeric label. + + Raises + ------ + NotImplementedError + If not overridden in subclass. + """ + raise NotImplementedError( + "Subclasses must override _role_to_label() to define role-to-label mapping" + ) diff --git a/tutorials/tutorial_train_brain_model.md b/tutorials/tutorial_train_brain_model.md new file mode 100644 index 000000000..3d5b58f39 --- /dev/null +++ b/tutorials/tutorial_train_brain_model.md @@ -0,0 +1,2015 @@ +# Training TBModel on Auditory Cortex Data + +This notebook demonstrates three different tasks using the A123 mouse auditory cortex dataset: + +1. **Graph-level Classification**: Predict frequency bin (0-8) from graph structure +2. **Triangle Classification**: Classify topological role of triangles in the correlation graph +3. 
+
+We'll show how to load the dataset, apply lifting transformations, define a backbone, and train a `TBModel` using `TBLoss` and `TBOptimizer`.
+
+Requirements: the project installed and importable (on PYTHONPATH), plus the optional dependencies (torch_geometric, networkx, ripser/persim) if you want the advanced features.
+
+
+```python
+import os
+os.chdir('..')
+```
+
+
+```python
+# 1) Imports
+import torch
+import numpy as np
+import lightning as pl
+from omegaconf import OmegaConf
+
+# Data loading / preprocessing utilities from the repo
+from topobench.data.loaders.graph.a123_loader import A123DatasetLoader
+from topobench.dataloader.dataloader import TBDataloader
+from topobench.data.preprocessor import PreProcessor
+
+# Model / training building blocks
+from topobench.model.model import TBModel
+# example backbone building block (SCN2 is optional; we provide a tiny custom backbone below)
+# from topomodelx.nn.simplicial.scn2 import SCN2
+from topobench.nn.wrappers.simplicial import SCNWrapper
+from topobench.nn.encoders import AllCellFeatureEncoder
+from topobench.nn.readouts import PropagateSignalDown
+
+# Optimization / evaluation
+from topobench.loss.loss import TBLoss
+from topobench.optimizer import TBOptimizer
+from topobench.evaluator.evaluator import TBEvaluator
+
+print('Imports OK')
+```
+
+    Imports OK
+
+
+    /Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/outdated/__init__.py:36: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
+      from pkg_resources import parse_version
+
+
+
+```python
+# 2) Configurations for different tasks
+# Note: We'll demonstrate each task separately by changing the specific_task parameter
+
+# Common loader config
+loader_config_base = {
+    'data_domain': 'graph',
+    'data_type': 'A123',
+    'data_name': 'a123_cortex_m',
+    'data_dir': './data/a123/',
+    'corr_threshold': 0.3,  # Higher threshold ensures graphs have meaningful edges
+}
+
+# Transform config: using CellCycleLifting (more robust for graphs with few edges)
+# CellCycleLifting finds cycles and lifts them to 2-cells, and handles empty graphs gracefully
+transform_config = {
+    'transform_type': 'lifting',
+    'transform_name': 'CellCycleLifting',
+    'max_cell_length': None,  # No limit on cycle length
+}
+
+split_config = {
+    'learning_setting': 'inductive',
+    'split_type': 'random',
+    'data_seed': 0,
+    'data_split_dir': './data/a123/splits/',
+    'train_prop': 0.5,
+}
+
+# Task configurations
+tasks = {
+    'graph_classification': {
+        'description': 'Graph-level classification (predict frequency bin 0-8)',
+        'specific_task': 'classification',
+        'in_channels': 3,
+        'out_channels': 9,
+        'task_level': 'graph',
+    },
+    'triangle_classification': {
+        'description': 'Triangle classification (predict topological role, 9 classes)',
+        'specific_task': 'triangle_classification',
+        'in_channels': 3,
+        'out_channels': 9,
+        'task_level': 'graph',
+    },
+    'triangle_common_neighbors': {
+        'description': 'Triangle common-neighbors (predict # common neighbors, 9 classes)',
+        'specific_task': 'triangle_common_neighbors',
+        'in_channels': 3,
+        'out_channels': 9,
+        'task_level': 'graph',
+    }
+}
+
+# Select task to run (change to 'graph_classification', 'triangle_classification' or 'triangle_common_neighbors' to run different tasks)
+TASK_NAME = 
'triangle_common_neighbors' +TASK_CONFIG = tasks[TASK_NAME] + +print(f"Selected task: {TASK_NAME}") +print(f"Description: {TASK_CONFIG['description']}") + +# Create loader config with specific task +loader_config = OmegaConf.create({**loader_config_base, 'specific_task': TASK_CONFIG['specific_task']}) + +dim_hidden = 16 +in_channels = TASK_CONFIG['in_channels'] +out_channels = TASK_CONFIG['out_channels'] + +readout_config = { + 'readout_name': 'PropagateSignalDown', + 'num_cell_dimensions': 1, + 'hidden_dim': dim_hidden, + 'out_channels': out_channels, + 'task_level': TASK_CONFIG['task_level'], + 'pooling_type': 'sum', +} + +loss_config = { + 'dataset_loss': { + 'task': 'classification', + 'loss_type': 'cross_entropy', + } +} + +evaluator_config = { + 'task': 'classification', + 'num_classes': out_channels, + 'metrics': ['f1', 'precision', 'recall', 'accuracy'], +} + +optimizer_config = { + 'optimizer_id': 'Adam', + 'parameters': {'lr': 0.001, 'weight_decay': 0.0005}, +} + +# Convert to OmegaConf +transform_config = OmegaConf.create(transform_config) +split_config = OmegaConf.create(split_config) +readout_config = OmegaConf.create(readout_config) +loss_config = OmegaConf.create(loss_config) +evaluator_config = OmegaConf.create(evaluator_config) +optimizer_config = OmegaConf.create(optimizer_config) + +print('Configs created') +print(f"Loader config: {loader_config}") +print(f"Input channels: {in_channels}, Output channels: {out_channels}") +``` + + Selected task: triangle_common_neighbors + Description: Triangle common-neighbors (predict # common neighbors, 9 classes) + Configs created + Loader config: {'data_domain': 'graph', 'data_type': 'A123', 'data_name': 'a123_cortex_m', 'data_dir': './data/a123/', 'corr_threshold': 0.3, 'specific_task': 'triangle_common_neighbors'} + Input channels: 3, Output channels: 9 + + + +```python +# 3) Loading the data + +# Use the A123-specific loader (A123DatasetLoader) to construct the dataset +graph_loader = A123DatasetLoader(loader_config) + +dataset, dataset_dir = graph_loader.load() +print(f'Dataset loaded: {len(dataset)} samples') + +# For triangle-level tasks, skip lifting transformations (triangles have no edge_index) +# Only apply lifting for graph-level classification +task_type = TASK_CONFIG['specific_task'] +if task_type in ['triangle_classification', 'triangle_common_neighbors']: + # Skip lifting for triangle tasks - they don't have graph structure + print(f"Task '{task_type}' uses triangle-level features (no edge_index)") + print("Skipping lifting transformation for triangle data") + preprocessor = PreProcessor(dataset, dataset_dir, transforms_config=None) +else: + # Apply lifting for graph-level tasks + preprocessor = PreProcessor(dataset, dataset_dir, transform_config) + +dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config) +print(f'Dataset splits created:') +print(f' Train: {len(dataset_train)} samples') +print(f' Val: {len(dataset_val)} samples') +print(f' Test: {len(dataset_test)} samples') + +# create the TopoBench datamodule / dataloader wrappers +datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32) + +print('Datasets and datamodule ready') +``` + + Processing... + + + [A123] Processing dataset from: data/a123_cortex_m/raw + [A123] Files in raw_dir: ['Auditory cortex data', '__MACOSX'] + [A123] Starting extract_samples()... 
+    Processing session 0: allPlanesVariables27-Feb-2021.mat
+    Processing session 1: allPlanesVariables27-Feb-2021.mat
+    Processing session 2: allPlanesVariables27-Feb-2021.mat
+    Processing session 3: allPlanesVariables27-Feb-2021.mat
+    Processing session 4: allPlanesVariables27-Feb-2021.mat
+    Processing session 5: allPlanesVariables27-Feb-2021.mat
+    Processing session 6: allPlanesVariables27-Feb-2021.mat
+    Processing session 7: allPlanesVariables27-Feb-2021.mat
+    [A123] Extracted 250 samples
+    [A123] Converting sample 0/250 to PyG Data...
+    [A123] Converting sample 100/250 to PyG Data...
+    [A123] Converting sample 200/250 to PyG Data...
+    [A123] Collating 250 samples (removed 0 empty graphs)...
+    [A123] Saving processed data to data/a123_cortex_m/processed/data.pt...
+    [A123] Triangle common-neighbours task enabled. Creating dataset...
+    [A123] Starting triangle extraction from graphs...
+    [A123] Processing graph 0: extracting triangles...
+    [A123] Found 2268 triangles in graph 0
+    [A123] Processing graph 1: extracting triangles...
+    [A123] Found 1277 triangles in graph 1
+    [A123] Processing graph 2: extracting triangles...
+    [A123] Found 946 triangles in graph 2
+    ...
+    [A123] Processed 150/250 graphs (2s elapsed, ~1s remaining)...
+    [A123] Processing graph 150: extracting triangles...
+    [A123] Found 154 triangles in graph 150
+    [A123] Processing graph 151: extracting triangles...
+    [A123] Found 13 triangles in graph 151
+    [A123] Processing graph 152: extracting triangles...
+    [A123] Found 7 triangles in graph 152
+    [A123] Processing graph 153: extracting triangles...
+ [A123] Found 5 triangles in graph 153 + [A123] Processing graph 154: extracting triangles... + [A123] Found 119 triangles in graph 154 + [A123] Processing graph 155: extracting triangles... + [A123] Found 160 triangles in graph 155 + [A123] Processing graph 156: extracting triangles... + [A123] Found 15 triangles in graph 156 + [A123] Processing graph 157: extracting triangles... + [A123] Found 10 triangles in graph 157 + [A123] Processing graph 158: extracting triangles... + [A123] Found 8 triangles in graph 158 + [A123] Processing graph 159: extracting triangles... + [A123] Found 30 triangles in graph 159 + [A123] Processed 160/250 graphs (2s elapsed, ~1s remaining)... + [A123] Processing graph 160: extracting triangles... + [A123] Found 13 triangles in graph 160 + [A123] Processing graph 161: extracting triangles... + [A123] Found 183 triangles in graph 161 + [A123] Processing graph 162: extracting triangles... + [A123] Found 177 triangles in graph 162 + [A123] Processing graph 163: extracting triangles... + [A123] Found 0 triangles in graph 163 + [A123] Processing graph 164: extracting triangles... + [A123] Found 53 triangles in graph 164 + [A123] Processing graph 165: extracting triangles... + [A123] Found 21 triangles in graph 165 + [A123] Processing graph 166: extracting triangles... + [A123] Found 21 triangles in graph 166 + [A123] Processing graph 167: extracting triangles... + [A123] Found 59 triangles in graph 167 + [A123] Processing graph 168: extracting triangles... + [A123] Found 52 triangles in graph 168 + [A123] Processing graph 169: extracting triangles... + [A123] Found 7 triangles in graph 169 + [A123] Processed 170/250 graphs (2s elapsed, ~1s remaining)... + [A123] Processing graph 170: extracting triangles... + [A123] Found 35 triangles in graph 170 + [A123] Processing graph 171: extracting triangles... + [A123] Found 916 triangles in graph 171 + [A123] Processing graph 172: extracting triangles... + [A123] Found 142 triangles in graph 172 + [A123] Processing graph 173: extracting triangles... + [A123] Found 23 triangles in graph 173 + [A123] Processing graph 174: extracting triangles... + [A123] Found 467 triangles in graph 174 + [A123] Processing graph 175: extracting triangles... + [A123] Found 27 triangles in graph 175 + [A123] Processing graph 176: extracting triangles... + [A123] Found 148 triangles in graph 176 + [A123] Processing graph 177: extracting triangles... + [A123] Found 65 triangles in graph 177 + [A123] Processing graph 178: extracting triangles... + [A123] Found 21 triangles in graph 178 + [A123] Processing graph 179: extracting triangles... + [A123] Found 80 triangles in graph 179 + [A123] Processed 180/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 180: extracting triangles... + [A123] Found 20 triangles in graph 180 + [A123] Processing graph 181: extracting triangles... + [A123] Found 10 triangles in graph 181 + [A123] Processing graph 182: extracting triangles... + [A123] Found 35 triangles in graph 182 + [A123] Processing graph 183: extracting triangles... + [A123] Found 32 triangles in graph 183 + [A123] Processing graph 184: extracting triangles... + [A123] Found 197 triangles in graph 184 + [A123] Processing graph 185: extracting triangles... + [A123] Found 26 triangles in graph 185 + [A123] Processing graph 186: extracting triangles... + [A123] Found 76 triangles in graph 186 + [A123] Processing graph 187: extracting triangles... 
+ [A123] Found 54 triangles in graph 187 + [A123] Processing graph 188: extracting triangles... + [A123] Found 31 triangles in graph 188 + [A123] Processing graph 189: extracting triangles... + [A123] Found 22 triangles in graph 189 + [A123] Processed 190/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 190: extracting triangles... + [A123] Found 15 triangles in graph 190 + [A123] Processing graph 191: extracting triangles... + [A123] Found 107 triangles in graph 191 + [A123] Processing graph 192: extracting triangles... + [A123] Found 544 triangles in graph 192 + [A123] Processing graph 193: extracting triangles... + [A123] Found 29 triangles in graph 193 + [A123] Processing graph 194: extracting triangles... + [A123] Found 226 triangles in graph 194 + [A123] Processing graph 195: extracting triangles... + [A123] Found 298 triangles in graph 195 + [A123] Processing graph 196: extracting triangles... + [A123] Found 9 triangles in graph 196 + [A123] Processing graph 197: extracting triangles... + [A123] Found 5 triangles in graph 197 + [A123] Processing graph 198: extracting triangles... + [A123] Found 11 triangles in graph 198 + [A123] Processing graph 199: extracting triangles... + [A123] Found 166 triangles in graph 199 + [A123] Processed 200/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 200: extracting triangles... + [A123] Found 162 triangles in graph 200 + [A123] Processing graph 201: extracting triangles... + [A123] Found 88 triangles in graph 201 + [A123] Processing graph 202: extracting triangles... + [A123] Found 139 triangles in graph 202 + [A123] Processing graph 203: extracting triangles... + [A123] Found 49 triangles in graph 203 + [A123] Processing graph 204: extracting triangles... + [A123] Found 15 triangles in graph 204 + [A123] Processing graph 205: extracting triangles... + [A123] Found 114 triangles in graph 205 + [A123] Processing graph 206: extracting triangles... + [A123] Found 792 triangles in graph 206 + [A123] Processing graph 207: extracting triangles... + [A123] Found 65 triangles in graph 207 + [A123] Processing graph 208: extracting triangles... + [A123] Found 222 triangles in graph 208 + [A123] Processing graph 209: extracting triangles... + [A123] Found 95 triangles in graph 209 + [A123] Processed 210/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 210: extracting triangles... + [A123] Found 58 triangles in graph 210 + [A123] Processing graph 211: extracting triangles... + [A123] Found 51 triangles in graph 211 + [A123] Processing graph 212: extracting triangles... + [A123] Found 5 triangles in graph 212 + [A123] Processing graph 213: extracting triangles... + [A123] Found 25 triangles in graph 213 + [A123] Processing graph 214: extracting triangles... + [A123] Found 26 triangles in graph 214 + [A123] Processing graph 215: extracting triangles... + [A123] Found 6 triangles in graph 215 + [A123] Processing graph 216: extracting triangles... + [A123] Found 38 triangles in graph 216 + [A123] Processing graph 217: extracting triangles... + [A123] Found 27 triangles in graph 217 + [A123] Processing graph 218: extracting triangles... + [A123] Found 210 triangles in graph 218 + [A123] Processing graph 219: extracting triangles... + [A123] Found 24 triangles in graph 219 + [A123] Processed 220/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 220: extracting triangles... + [A123] Found 135 triangles in graph 220 + [A123] Processing graph 221: extracting triangles... 
+ [A123] Found 136 triangles in graph 221 + [A123] Processing graph 222: extracting triangles... + [A123] Found 64 triangles in graph 222 + [A123] Processing graph 223: extracting triangles... + [A123] Found 11 triangles in graph 223 + [A123] Processing graph 224: extracting triangles... + [A123] Found 43 triangles in graph 224 + [A123] Processing graph 225: extracting triangles... + [A123] Found 20 triangles in graph 225 + [A123] Processing graph 226: extracting triangles... + [A123] Found 177 triangles in graph 226 + [A123] Processing graph 227: extracting triangles... + [A123] Found 19 triangles in graph 227 + [A123] Processing graph 228: extracting triangles... + [A123] Found 92 triangles in graph 228 + [A123] Processing graph 229: extracting triangles... + [A123] Found 12 triangles in graph 229 + [A123] Processed 230/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 230: extracting triangles... + [A123] Found 418 triangles in graph 230 + [A123] Processing graph 231: extracting triangles... + [A123] Found 86 triangles in graph 231 + [A123] Processing graph 232: extracting triangles... + [A123] Found 152 triangles in graph 232 + [A123] Processing graph 233: extracting triangles... + [A123] Found 387 triangles in graph 233 + [A123] Processing graph 234: extracting triangles... + [A123] Found 4 triangles in graph 234 + [A123] Processing graph 235: extracting triangles... + [A123] Found 283 triangles in graph 235 + [A123] Processing graph 236: extracting triangles... + [A123] Found 44 triangles in graph 236 + [A123] Processing graph 237: extracting triangles... + [A123] Found 189 triangles in graph 237 + [A123] Processing graph 238: extracting triangles... + [A123] Found 120 triangles in graph 238 + [A123] Processing graph 239: extracting triangles... + [A123] Found 63 triangles in graph 239 + [A123] Processed 240/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 240: extracting triangles... + [A123] Found 868 triangles in graph 240 + [A123] Processing graph 241: extracting triangles... + [A123] Found 920 triangles in graph 241 + [A123] Processing graph 242: extracting triangles... + [A123] Found 680 triangles in graph 242 + [A123] Processing graph 243: extracting triangles... + [A123] Found 59 triangles in graph 243 + [A123] Processing graph 244: extracting triangles... + [A123] Found 266 triangles in graph 244 + [A123] Processing graph 245: extracting triangles... + [A123] Found 1680 triangles in graph 245 + [A123] Processing graph 246: extracting triangles... + [A123] Found 3037 triangles in graph 246 + [A123] Processing graph 247: extracting triangles... + [A123] Found 654 triangles in graph 247 + [A123] Processing graph 248: extracting triangles... + [A123] Found 2370 triangles in graph 248 + [A123] Processing graph 249: extracting triangles... + [A123] Found 17 triangles in graph 249 + [A123] Triangle extraction completed in 2s, found 335458 triangles + [A123] Creating triangle common-neighbors task... + [A123] Processing graph 175: extracting triangles... + [A123] Found 27 triangles in graph 175 + [A123] Processing graph 176: extracting triangles... + [A123] Found 148 triangles in graph 176 + [A123] Processing graph 177: extracting triangles... + [A123] Found 65 triangles in graph 177 + [A123] Processing graph 178: extracting triangles... + [A123] Found 21 triangles in graph 178 + [A123] Processing graph 179: extracting triangles... + [A123] Found 80 triangles in graph 179 + [A123] Processed 180/250 graphs (2s elapsed, ~0s remaining)... 
+ [A123] Processing graph 180: extracting triangles... + [A123] Found 20 triangles in graph 180 + [A123] Processing graph 181: extracting triangles... + [A123] Found 10 triangles in graph 181 + [A123] Processing graph 182: extracting triangles... + [A123] Found 35 triangles in graph 182 + [A123] Processing graph 183: extracting triangles... + [A123] Found 32 triangles in graph 183 + [A123] Processing graph 184: extracting triangles... + [A123] Found 197 triangles in graph 184 + [A123] Processing graph 185: extracting triangles... + [A123] Found 26 triangles in graph 185 + [A123] Processing graph 186: extracting triangles... + [A123] Found 76 triangles in graph 186 + [A123] Processing graph 187: extracting triangles... + [A123] Found 54 triangles in graph 187 + [A123] Processing graph 188: extracting triangles... + [A123] Found 31 triangles in graph 188 + [A123] Processing graph 189: extracting triangles... + [A123] Found 22 triangles in graph 189 + [A123] Processed 190/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 190: extracting triangles... + [A123] Found 15 triangles in graph 190 + [A123] Processing graph 191: extracting triangles... + [A123] Found 107 triangles in graph 191 + [A123] Processing graph 192: extracting triangles... + [A123] Found 544 triangles in graph 192 + [A123] Processing graph 193: extracting triangles... + [A123] Found 29 triangles in graph 193 + [A123] Processing graph 194: extracting triangles... + [A123] Found 226 triangles in graph 194 + [A123] Processing graph 195: extracting triangles... + [A123] Found 298 triangles in graph 195 + [A123] Processing graph 196: extracting triangles... + [A123] Found 9 triangles in graph 196 + [A123] Processing graph 197: extracting triangles... + [A123] Found 5 triangles in graph 197 + [A123] Processing graph 198: extracting triangles... + [A123] Found 11 triangles in graph 198 + [A123] Processing graph 199: extracting triangles... + [A123] Found 166 triangles in graph 199 + [A123] Processed 200/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 200: extracting triangles... + [A123] Found 162 triangles in graph 200 + [A123] Processing graph 201: extracting triangles... + [A123] Found 88 triangles in graph 201 + [A123] Processing graph 202: extracting triangles... + [A123] Found 139 triangles in graph 202 + [A123] Processing graph 203: extracting triangles... + [A123] Found 49 triangles in graph 203 + [A123] Processing graph 204: extracting triangles... + [A123] Found 15 triangles in graph 204 + [A123] Processing graph 205: extracting triangles... + [A123] Found 114 triangles in graph 205 + [A123] Processing graph 206: extracting triangles... + [A123] Found 792 triangles in graph 206 + [A123] Processing graph 207: extracting triangles... + [A123] Found 65 triangles in graph 207 + [A123] Processing graph 208: extracting triangles... + [A123] Found 222 triangles in graph 208 + [A123] Processing graph 209: extracting triangles... + [A123] Found 95 triangles in graph 209 + [A123] Processed 210/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 210: extracting triangles... + [A123] Found 58 triangles in graph 210 + [A123] Processing graph 211: extracting triangles... + [A123] Found 51 triangles in graph 211 + [A123] Processing graph 212: extracting triangles... + [A123] Found 5 triangles in graph 212 + [A123] Processing graph 213: extracting triangles... + [A123] Found 25 triangles in graph 213 + [A123] Processing graph 214: extracting triangles... 
+ [A123] Found 26 triangles in graph 214 + [A123] Processing graph 215: extracting triangles... + [A123] Found 6 triangles in graph 215 + [A123] Processing graph 216: extracting triangles... + [A123] Found 38 triangles in graph 216 + [A123] Processing graph 217: extracting triangles... + [A123] Found 27 triangles in graph 217 + [A123] Processing graph 218: extracting triangles... + [A123] Found 210 triangles in graph 218 + [A123] Processing graph 219: extracting triangles... + [A123] Found 24 triangles in graph 219 + [A123] Processed 220/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 220: extracting triangles... + [A123] Found 135 triangles in graph 220 + [A123] Processing graph 221: extracting triangles... + [A123] Found 136 triangles in graph 221 + [A123] Processing graph 222: extracting triangles... + [A123] Found 64 triangles in graph 222 + [A123] Processing graph 223: extracting triangles... + [A123] Found 11 triangles in graph 223 + [A123] Processing graph 224: extracting triangles... + [A123] Found 43 triangles in graph 224 + [A123] Processing graph 225: extracting triangles... + [A123] Found 20 triangles in graph 225 + [A123] Processing graph 226: extracting triangles... + [A123] Found 177 triangles in graph 226 + [A123] Processing graph 227: extracting triangles... + [A123] Found 19 triangles in graph 227 + [A123] Processing graph 228: extracting triangles... + [A123] Found 92 triangles in graph 228 + [A123] Processing graph 229: extracting triangles... + [A123] Found 12 triangles in graph 229 + [A123] Processed 230/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 230: extracting triangles... + [A123] Found 418 triangles in graph 230 + [A123] Processing graph 231: extracting triangles... + [A123] Found 86 triangles in graph 231 + [A123] Processing graph 232: extracting triangles... + [A123] Found 152 triangles in graph 232 + [A123] Processing graph 233: extracting triangles... + [A123] Found 387 triangles in graph 233 + [A123] Processing graph 234: extracting triangles... + [A123] Found 4 triangles in graph 234 + [A123] Processing graph 235: extracting triangles... + [A123] Found 283 triangles in graph 235 + [A123] Processing graph 236: extracting triangles... + [A123] Found 44 triangles in graph 236 + [A123] Processing graph 237: extracting triangles... + [A123] Found 189 triangles in graph 237 + [A123] Processing graph 238: extracting triangles... + [A123] Found 120 triangles in graph 238 + [A123] Processing graph 239: extracting triangles... + [A123] Found 63 triangles in graph 239 + [A123] Processed 240/250 graphs (2s elapsed, ~0s remaining)... + [A123] Processing graph 240: extracting triangles... + [A123] Found 868 triangles in graph 240 + [A123] Processing graph 241: extracting triangles... + [A123] Found 920 triangles in graph 241 + [A123] Processing graph 242: extracting triangles... + [A123] Found 680 triangles in graph 242 + [A123] Processing graph 243: extracting triangles... + [A123] Found 59 triangles in graph 243 + [A123] Processing graph 244: extracting triangles... + [A123] Found 266 triangles in graph 244 + [A123] Processing graph 245: extracting triangles... + [A123] Found 1680 triangles in graph 245 + [A123] Processing graph 246: extracting triangles... + [A123] Found 3037 triangles in graph 246 + [A123] Processing graph 247: extracting triangles... + [A123] Found 654 triangles in graph 247 + [A123] Processing graph 248: extracting triangles... 
+ [A123] Found 2370 triangles in graph 248
+ [A123] Processing graph 249: extracting triangles...
+ [A123] Found 17 triangles in graph 249
+ [A123] Triangle extraction completed in 2s, found 335458 triangles
+ [A123] Creating triangle common-neighbors task...
+ [A123] Created 335458 triangle CN samples
+ [A123] Collating 335458 triangle CN samples...
+ [A123] Saving triangle CN dataset to data/a123_cortex_m/processed/data_triangles_common_neighbors.pt...
+ [A123] Triangle CN dataset saved!
+ [A123] Processing complete!
+
+
+ Done!
+ Processing...
+
+
+ [A123 Loader] Loaded triangle common-neighbours task dataset
+ Dataset loaded: 335458 samples
+ Task 'triangle_common_neighbors' uses triangle-level features (no edge_index)
+ Skipping lifting transformation for triangle data
+
+
+ Done!
+
+
+ Dataset splits created:
+   Train: 167729 samples
+   Val: 83864 samples
+   Test: 83865 samples
+ Datasets and datamodule ready
+
+The triangle CN labels are heavily imbalanced (class 8 dominates every split, as the distributions printed below show), so we rebalance by undersampling each split to a fixed number of samples per class.
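+Before rebalancing, note how `DataloadDataset` yields items: each item is a `(values, keys)` pair rather than a `Data` object, with the label `y` stored as the last entry of `values`. A minimal inspection sketch (this tuple layout is an assumption that matches the handling inside the undersampling helper below):
+
+```python
+# Peek at one training sample to confirm the (values, keys) layout.
+# Assumption: items are (values_list, keys_list) pairs, with the label
+# tensor 'y' as the final entry of values_list.
+values, keys = dataset_train[0]
+for key, value in zip(keys, values):
+    print(f"{key}: shape={getattr(value, 'shape', None)}")
+print("label:", int(values[-1].item()))
+```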
+```python
+import numpy as np
+from torch.utils.data import Subset
+
+
+def undersample_majority_class(dataset, target_samples_per_class=100, random_state=42):
+    """
+    Undersample all classes to a target number of samples per class.
+
+    Parameters
+    ----------
+    dataset : DataloadDataset
+        Dataset to undersample.
+    target_samples_per_class : int
+        Target number of samples per class (default: 100).
+    random_state : int
+        Random seed for reproducibility.
+
+    Returns
+    -------
+    torch.utils.data.Subset
+        Undersampled dataset.
+    """
+    np.random.seed(random_state)
+
+    # Handle DataloadDataset, which returns (values, keys) tuples
+    labels = []
+    for item in dataset:
+        if isinstance(item, (list, tuple)) and len(item) == 2:
+            values, keys = item
+            # The 'y' label is the last value in the list
+            y = values[-1]
+        elif hasattr(item, 'y'):
+            # Fallback: try to access the .y attribute
+            y = item.y
+        else:
+            continue  # Skip if we can't extract a label
+
+        # Convert tensor to scalar
+        if hasattr(y, 'item'):
+            labels.append(int(y.item()))
+        elif hasattr(y, '__len__') and len(y) == 1:
+            # Single-element tensor or array
+            labels.append(int(y[0]))
+        else:
+            labels.append(int(y))
+
+    # Check that we extracted at least one label
+    if len(labels) == 0:
+        raise ValueError(f"No labels extracted from dataset of size {len(dataset)}")
+
+    labels = np.array(labels)
+    unique_labels, counts = np.unique(labels, return_counts=True)
+
+    print("Original class distribution:")
+    for label, count in zip(unique_labels, counts):
+        print(f"  Class {label}: {count} samples")
+
+    # Get indices for each class
+    indices_by_class = {label: np.where(labels == label)[0] for label in unique_labels}
+
+    # Undersample each class to target_samples_per_class (or fewer if the class is smaller)
+    undersampled_indices = []
+    for label in unique_labels:
+        indices = indices_by_class[label]
+        actual_samples = min(len(indices), target_samples_per_class)
+        selected = np.random.choice(indices, size=actual_samples, replace=False)
+        undersampled_indices.extend(selected)
+
+    # Shuffle the final indices
+    undersampled_indices = np.random.permutation(undersampled_indices)
+
+    # Create a subset of the dataset
+    undersampled_dataset = Subset(dataset, undersampled_indices)
+
+    # Report the new label distribution
+    new_labels = labels[undersampled_indices]
+    new_unique, new_counts = np.unique(new_labels, return_counts=True)
+
+    print(f"\nAfter undersampling to {target_samples_per_class} per class:")
+    for label, count in zip(new_unique, new_counts):
+        print(f"  Class {label}: {count} samples")
+
+    imbalance_ratio_before = counts.max() / counts.min()
+    imbalance_ratio_after = new_counts.max() / new_counts.min()
+    print(f"\nImbalance ratio: {imbalance_ratio_before:.2f} → {imbalance_ratio_after:.2f}")
+    print(f"Dataset size: {len(dataset)} → {len(undersampled_dataset)} samples\n")
+
+    return undersampled_dataset
+
+
+# Apply undersampling to the training set
+print("Undersampling training set...")
+dataset_train = undersample_majority_class(dataset_train, target_samples_per_class=200, random_state=0)
+
+# Optionally also undersample the validation set for consistency
+print("Undersampling validation set...")
+dataset_val = undersample_majority_class(dataset_val, target_samples_per_class=200, random_state=0)
+
+# Optionally also undersample the test set for consistency
+print("Undersampling test set...")
+dataset_test = undersample_majority_class(dataset_test, target_samples_per_class=200, random_state=0)
+
+# Recreate the datamodule with the undersampled datasets
+datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)
+
+print('Datasets rebalanced and datamodule recreated')
+```
+
+ Undersampling training set...
+ Original class distribution:
+   Class 0: 507 samples
+   Class 1: 1119 samples
+   Class 2: 1574 samples
+   Class 3: 1804 samples
+   Class 4: 1946 samples
+   Class 5: 2081 samples
+   Class 6: 2203 samples
+   Class 7: 2157 samples
+   Class 8: 154338 samples
+
+ After undersampling to 200 per class:
+   Class 0: 200 samples
+   Class 1: 200 samples
+   Class 2: 200 samples
+   Class 3: 200 samples
+   Class 4: 200 samples
+   Class 5: 200 samples
+   Class 6: 200 samples
+   Class 7: 200 samples
+   Class 8: 200 samples
+
+ Imbalance ratio: 304.41 → 1.00
+ Dataset size: 167729 → 1800 samples
+
+ Undersampling validation set...
+ Original class distribution:
+   Class 0: 288 samples
+   Class 1: 542 samples
+   Class 2: 788 samples
+   Class 3: 966 samples
+   Class 4: 988 samples
+   Class 5: 1043 samples
+   Class 6: 1078 samples
+   Class 7: 1116 samples
+   Class 8: 77055 samples
+
+ After undersampling to 200 per class:
+   Class 0: 200 samples
+   Class 1: 200 samples
+   Class 2: 200 samples
+   Class 3: 200 samples
+   Class 4: 200 samples
+   Class 5: 200 samples
+   Class 6: 200 samples
+   Class 7: 200 samples
+   Class 8: 200 samples
+
+ Imbalance ratio: 267.55 → 1.00
+ Dataset size: 83864 → 1800 samples
+
+ Undersampling test set...
+ Original class distribution:
+   Class 0: 243 samples
+   Class 1: 541 samples
+   Class 2: 760 samples
+   Class 3: 909 samples
+   Class 4: 983 samples
+   Class 5: 1023 samples
+   Class 6: 1072 samples
+   Class 7: 1171 samples
+   Class 8: 77163 samples
+
+ After undersampling to 200 per class:
+   Class 0: 200 samples
+   Class 1: 200 samples
+   Class 2: 200 samples
+   Class 3: 200 samples
+   Class 4: 200 samples
+   Class 5: 200 samples
+   Class 6: 200 samples
+   Class 7: 200 samples
+   Class 8: 200 samples
+
+ Imbalance ratio: 317.54 → 1.00
+ Dataset size: 83865 → 1800 samples
+
+ Datasets rebalanced and datamodule recreated
+
+
+## 4) Backbone definition
+
+We implement a tiny backbone as a `pl.LightningModule` that computes hyperedge features from node features via the incidence matrix, $X_1 = B_1 \cdot X_0$, and applies a separate linear layer with ReLU to the node and hyperedge features.
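+To make the shapes explicit, assume the preprocessor stores the incidence matrix with hyperedges indexing rows (the code below falls back to the transpose when the orientation differs):
+
+$$
+X_1 = B_1 X_0, \qquad B_1 \in \{0, 1\}^{m \times n}, \quad X_0 \in \mathbb{R}^{n \times d}, \quad X_1 \in \mathbb{R}^{m \times d},
+$$
+
+so row $e$ of $X_1$ sums the features of the nodes incident to hyperedge $e$.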
+ + +```python +class MyBackbone(pl.LightningModule): + def __init__(self, dim_hidden): + super().__init__() + self.linear_0 = torch.nn.Linear(dim_hidden, dim_hidden) + self.linear_1 = torch.nn.Linear(dim_hidden, dim_hidden) + + def forward(self, batch): + # batch.x_0: node features (dense tensor of shape [N, dim_hidden]) + # batch.incidence_hyperedges: sparse incidence matrix with shape [m, n] or [n, m] depending on preprocessor convention + x_0 = batch.x_0 + incidence_hyperedges = getattr(batch, 'incidence_hyperedges', None) + if incidence_hyperedges is None: + # fallback: try incidence as batch.incidence if available + incidence_hyperedges = getattr(batch, 'incidence', None) + + # compute hyperedge features X_1 = B_1 dot X_0 (we assume B_1 is sparse and transposed appropriately) + x_1 = None + if incidence_hyperedges is not None: + try: + x_1 = torch.sparse.mm(incidence_hyperedges, x_0) + except Exception: + # if orientation differs, try transpose + x_1 = torch.sparse.mm(incidence_hyperedges.T, x_0) + else: + # no incidence available: create a zero hyperedge feature placeholder + x_1 = torch.zeros_like(x_0) + + x_0 = self.linear_0(x_0) + x_0 = torch.relu(x_0) + + x_1 = self.linear_1(x_1) + x_1 = torch.relu(x_1) + + model_out = {'labels': batch.y, 'batch_0': getattr(batch, 'batch_0', None)} + model_out['x_0'] = x_0 + model_out['hyperedge'] = x_1 + return model_out + +print('Backbone defined') +``` + + Backbone defined + + + +```python +# 5) Model initialization (components) +backbone = MyBackbone(dim_hidden) +readout = PropagateSignalDown(**readout_config) +loss = TBLoss(**loss_config) +feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels], out_channels=dim_hidden) +evaluator = TBEvaluator(**evaluator_config) +optimizer = TBOptimizer(**optimizer_config) + +print('Components instantiated') +``` + + Components instantiated + + + +```python +# 6) Instantiate TBModel +model = TBModel(backbone=backbone, + backbone_wrapper=None, + readout=readout, + loss=loss, + feature_encoder=feature_encoder, + evaluator=evaluator, + optimizer=optimizer, + compile=False) + +# Print a short summary (repr) to verify construction +print(model) +``` + + TBModel(backbone=MyBackbone( + (linear_0): Linear(in_features=16, out_features=16, bias=True) + (linear_1): Linear(in_features=16, out_features=16, bias=True) + ), readout=PropagateSignalDown(num_cell_dimensions=0, self.hidden_dim=16, readout_name=PropagateSignalDown, loss=TBLoss(losses=[DatasetLoss(task=classification, loss_type=cross_entropy)]), feature_encoder=AllCellFeatureEncoder(in_channels=[3], out_channels=16, dimensions=range(0, 1))) + + + +```python +# 7) Training loop (Lightning trainer) +# Suppress some warnings for cleaner output +import warnings +warnings.filterwarnings('ignore', category=UserWarning, module='torchmetrics') + +trainer = pl.Trainer( + max_epochs=50, # reduced for faster iteration + accelerator='cpu', + enable_progress_bar=True, + log_every_n_steps=1, + enable_model_summary=False, # skip the model summary printout +) +trainer.fit(model, datamodule) +train_metrics = trainer.callback_metrics + +print('\nTraining finished. Collected metrics:') +for key, val in train_metrics.items(): + try: + print(f'{key:25s} {float(val):.4f}') + except Exception: + print(key, val) +``` + + GPU available: True (mps), used: False + TPU available: False, using: 0 TPU cores + HPU available: False, using: 0 HPUs + /Users/mariayuffa/anaconda3/envs/tb3/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. 
You can set it by doing `Trainer(accelerator='gpu')`.
+
+
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
+ ┃        Test metric        ┃       DataLoader 0        ┃
+ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
+ │       test/accuracy       │    0.1111111119389534     │
+ │          test/f1          │    0.02222222276031971    │
+ │         test/loss         │     2.197321653366089     │
+ │      test/precision       │   0.012345679104328156    │
+ │        test/recall        │    0.1111111119389534     │
+ └───────────────────────────┴───────────────────────────┘
+
+
+ Test metrics:
+   test/loss        2.1973
+   test/accuracy    0.1111
+   test/f1          0.0222
+   test/precision   0.0123
+   test/recall      0.1111
+
+With nine balanced classes, an accuracy of 0.1111 is chance level (1/9), so this short CPU run mainly verifies that the pipeline is wired correctly rather than providing a tuned baseline.
+
+## Running Other Tasks
+
+To run a different task, modify the `TASK_NAME` variable in cell 4 (configurations) to one of:
+- `graph_classification` (default): Predict frequency bin from graph structure
+- `triangle_classification`: Classify topological role of triangles (9 embedding × weight classes)
+- `triangle_common_neighbors`: Predict number of common neighbors for each triangle
+
+Then re-run the configuration cell and subsequent cells. The dataset will automatically load the appropriate task variant, and the model will be configured with the correct number of output classes (9 for all tasks).
+
+### Task Details:
+
+**Task 1: Graph-level Classification**
+- Input: Graph structure with node features (mean correlation, std correlation, noise diagonal)
+- Output: Frequency bin (0-8) representing the best frequency
+- Level: Graph-level prediction
+
+**Task 2: Triangle Classification**
+- Input: Topological features of triangles (3 edge weights from the correlation matrix)
+- Output: Triangle role classification (9 classes based on embedding × weight):
+  - Embedding classes: Core (many common neighbors), Bridge (some), Isolated (few)
+  - Weight classes: Strong (high correlation), Medium, Weak (low correlation)
+- Level: Triangle (motif) level prediction
+
+**Task 3: Triangle Common-Neighbors**
+- Input: Triangle node degrees (structural features)
+- Output: Number of common neighbors (0-8; the common-neighbor count is mapped to a class), as sketched below
+- Level: Triangle (motif) level prediction
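+For concreteness, here is a minimal sketch of how a common-neighbor label in the 0-8 range can be derived for a triangle with networkx. The helper name and the cap at `num_classes - 1` are illustrative assumptions, not the dataset's exact implementation:
+
+```python
+import networkx as nx
+
+
+def triangle_cn_label(G: nx.Graph, u, v, w, num_classes: int = 9) -> int:
+    """Map triangle (u, v, w) to a class in [0, num_classes - 1] by counting
+    the nodes adjacent to all three vertices (hypothetical helper)."""
+    common = set(G.neighbors(u)) & set(G.neighbors(v)) & set(G.neighbors(w))
+    common -= {u, v, w}  # a triangle vertex is not its own common neighbor
+    return min(len(common), num_classes - 1)
+
+
+# Toy usage: in K5, triangle (0, 1, 2) has exactly two common neighbors (3 and 4)
+G = nx.complete_graph(5)
+print(triangle_cn_label(G, 0, 1, 2))  # -> 2
+```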