diff --git a/README.md b/README.md
index b0392225..4508f751 100644
--- a/README.md
+++ b/README.md
@@ -94,10 +94,11 @@ Experiments below assume an 8-GPU setup.
 # Initialize submodules
 git submodule update --init --recursive
 
-# ARC-1
-python dataset/build_arc_dataset.py  # ARC offical + ConceptARC, 960 examples
+# ConceptARC (train only) + ARC-1
+python dataset/build_arc_dataset.py
+
 # ARC-2
-python dataset/build_arc_dataset.py --dataset-dirs dataset/raw-data/ARC-AGI-2/data --output-dir data/arc-2-aug-1000  # ARC-2 official, 1120 examples
+python dataset/build_arc_dataset.py --raw-dataset-dirs dataset/raw-data/ARC-AGI-2/data --processed-dataset-dir data/arc-2-aug-1000  # ARC-2 official, 1120 examples
 
 # Sudoku-Extreme
 python dataset/build_sudoku_dataset.py  # Full version
diff --git a/dataset/build_arc_dataset.py b/dataset/build_arc_dataset.py
index 2da5703e..d72b5b48 100644
--- a/dataset/build_arc_dataset.py
+++ b/dataset/build_arc_dataset.py
@@ -1,291 +1,383 @@
-from typing import List, Optional, Tuple, Dict
-from dataclasses import dataclass
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from enum import Enum
+from functools import cached_property, partial
+import itertools
 from pathlib import Path
-import os
-import json
+import random
+from typing import Annotated, Final, Literal, get_args
+
 import hashlib
+import json
 import numpy as np
-from glob import glob
+from numpy.typing import NDArray
+
+import tqdm
 
-from argdantic import ArgParser
-from pydantic import BaseModel
+from pydantic import BaseModel, BeforeValidator, Field, ConfigDict, TypeAdapter
 
 from common import PuzzleDatasetMetadata, dihedral_transform
 
 
-cli = ArgParser()
+ARC_MAX_GRID_SIZE: Final[int] = 30
+ARC_MAX_COLOR_VALUE: Final[int] = 9
+BLANK_PUZZLE_ID: Final[str] = ""
+BLACK_COLOR: Final[int] = 0
+N_PADDING_TOKENS: Final[int] = 1
+N_EOS_TOKENS: Final[int] = 1
+DIHEDRAL_SYMMETRIES: Final[int] = 8
+PAD_TOKEN: Final[int] = 0
+END_OF_SEQUENCE_TOKEN: Final[int] = 1
+COLOR_OFFSET_TOKEN: Final[int] = 2
+
+# TODO: I've removed the "set" logic as it was unused
+SET_NAME: Final[str] = "all"
+
+
+class ArcIOKey(str, Enum):
+    INPUT = "input"
+    OUTPUT = "output"
 
-class DataProcessConfig(BaseModel):
-    # ARC-1
-    dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
-    output_dir: str = "data/arc-aug-1000"
-
-    # ARC-2
-    # dataset_dirs: List[str] = ["dataset/raw-data/ARC-AGI-2/data"]
-    # output_dir: str = "data/arc-2-aug-1000"
 
+RawArcSplit = Annotated[
+    Literal["training", "evaluation"],
+    BeforeValidator(lambda x: "evaluation" if x == "evaluation" else "training"),  # ConceptARC data assumed to be training
+]
+RawArcSplitAdapter: TypeAdapter[RawArcSplit] = TypeAdapter(RawArcSplit)
+
+ProcessedArcSplit = Literal["train", "test"]
+ARCExampleType = Literal["train", "test"]
+RawPuzzle = dict[ARCExampleType, list[dict[ArcIOKey, list[list[int]]]]]
+
+# For clarity
+GridArray = NDArray[np.uint8]
+FlatArray = NDArray[np.uint8]
+
+
+class ARCDatasetBuildConfig(BaseModel):
     seed: int = 42
-    num_aug: int = 1000
-
-
-ARCMaxGridSize = 30
-ARCAugmentRetriesFactor = 5
-
+    num_aug: int = Field(default=1000, ge=0)
+    augment_retries_factor: int = Field(default=5, ge=1)
+
+    raw_dataset_dirs: list[str | Path] = ["dataset/raw-data/ARC-AGI/data", "dataset/raw-data/ConceptARC/corpus"]
+
+    processed_dataset_dir: str = "data/arc-aug-1000"
+    identifiers_filename: str = "identifiers"
+    metadata_filename: str = "dataset"
+
+
+class PuzzleExample(BaseModel):
+    example_type: ARCExampleType
+    input_grid: GridArray
+    output_grid: GridArray
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def pad(self) -> "PuzzleExample":
+        return self.model_copy(update={
+            "input_grid": _pad_grid(self.input_grid),
+            "output_grid": _pad_grid(self.output_grid),
+        })
+
 
-@dataclass
-class ARCPuzzle:
+def _grid_to_bytes(grid: GridArray) -> bytes:
+    return (
+        grid.shape[0].to_bytes(1, "little")
+        + grid.shape[1].to_bytes(1, "little")
+        + grid.tobytes()
+    )
+
+
+class ARCPuzzle(BaseModel):
     id: str
+    split: RawArcSplit
+    examples: list[PuzzleExample]
-    examples: List[Tuple[np.ndarray, np.ndarray]]
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
-
-def arc_grid_to_np(grid: List[List[int]]):
-    arr = np.array(grid)
-
-    # Shape check
-    assert arr.ndim == 2
-    assert arr.shape[0] <= ARCMaxGridSize and arr.shape[1] <= ARCMaxGridSize
-    # Element check
-    assert np.all((arr >= 0) & (arr <= 9))
-    return arr.astype(np.uint8)
-
-
-def np_grid_to_seq_translational_augment(inp: np.ndarray, out: np.ndarray, do_translation: bool):
-    # PAD: 0, <eos>: 1, digits: 2 ... 11
-    # Compute random top-left pad
-    if do_translation:
-        pad_r = np.random.randint(0, ARCMaxGridSize - max(inp.shape[0], out.shape[0]) + 1)
-        pad_c = np.random.randint(0, ARCMaxGridSize - max(inp.shape[1], out.shape[1]) + 1)
-    else:
-        pad_r = pad_c = 0
-
-    # Pad grid
-    result = []
-    for grid in [inp, out]:
-        nrow, ncol = grid.shape
-        grid = np.pad(grid + 2, ((pad_r, ARCMaxGridSize - pad_r - nrow), (pad_c, ARCMaxGridSize - pad_c - ncol)), constant_values=0)
-
-        # Add <eos>
-        eos_row, eos_col = pad_r + nrow, pad_c + ncol
-        if eos_row < ARCMaxGridSize:
-            grid[eos_row, pad_c:eos_col] = 1
-        if eos_col < ARCMaxGridSize:
-            grid[pad_r:eos_row, eos_col] = 1
-
-        result.append(grid.flatten())
-
-    return result
-
-
-def puzzle_hash(puzzle: dict):
-    # Hash the puzzle for checking equivalence
-    def _grid_hash(grid: np.ndarray):
-        buffer = [x.to_bytes(1) for x in grid.shape]
-        buffer.append(grid.tobytes())
+    @cached_property
+    def hash(self) -> str:
+        example_buffers = sorted(
+            _grid_to_bytes(example.input_grid) + _grid_to_bytes(example.output_grid)
+            for example in self.examples
+        )
+
+        hasher = hashlib.blake2b(digest_size=8)
+        for buf in example_buffers:
+            hasher.update(buf)
+
+        return hasher.hexdigest()
+
+    @classmethod
+    def from_raw_file(cls, filepath: Path, arc_split: RawArcSplit, puzzle_id: str) -> "ARCPuzzle":
+        with open(filepath) as f:
+            data = json.load(f)
 
-        return hashlib.sha256(b"".join(buffer)).hexdigest()
+        valid_example_types = get_args(ARCExampleType)
+        raw_puzzle_data = {
+            key: data[key] for key in valid_example_types if key in data
+        }
+        return cls._from_raw(raw_puzzle_data, arc_split, puzzle_id)
+
+    @classmethod
+    def _from_raw(cls, raw_puzzle: RawPuzzle, split: RawArcSplit, puzzle_id: str) -> "ARCPuzzle":
+        return cls(
+            id=puzzle_id,
+            split=split,
+            examples=[
+                PuzzleExample(
+                    example_type=example_type,
+                    input_grid=_parse_raw_grid(example[ArcIOKey.INPUT]),
+                    output_grid=_parse_raw_grid(example[ArcIOKey.OUTPUT]),
+                )
+                for example_type, examples in raw_puzzle.items()
+                for example in examples
+            ],
+        )
+
+
+def _parse_raw_grid(raw_grid: list[list[int]]) -> GridArray:
+    arr = np.array(raw_grid, dtype=np.uint8)
+
+    if arr.ndim != 2:
+        raise ValueError(f"Grid must be 2D, got {arr.ndim}D")
+    if arr.shape[0] > ARC_MAX_GRID_SIZE or arr.shape[1] > ARC_MAX_GRID_SIZE:
+        raise ValueError(f"Grid size {arr.shape} exceeds maximum {ARC_MAX_GRID_SIZE}")
+    if not np.all((arr >= BLACK_COLOR) & (arr <= ARC_MAX_COLOR_VALUE)):
+        raise ValueError(f"Grid values must be in range [0, {ARC_MAX_COLOR_VALUE}]")
+
+    return arr
+
+
+def _pad_grid(grid: GridArray, row_padding: int = 0, col_padding: int = 0) -> GridArray:
+    n_rows, n_cols = grid.shape
 
-    hashes = []
-    for example_type, example in puzzle.items():
-        for input, label in example.examples:
-            hashes.append(f"{_grid_hash(input)}|{_grid_hash(label)}")
-
-    hashes.sort()
-    return hashlib.sha256("|".join(hashes).encode()).hexdigest()
-
-
-def convert_single_arc_puzzle(results: dict, default_name: str, puzzle: dict, aug_count: int, dest_mapping: Dict[str, Tuple[str, str]]):
-    # Remove "name"
-    name = puzzle.pop("name", default_name)
+    padded = np.full((ARC_MAX_GRID_SIZE, ARC_MAX_GRID_SIZE), PAD_TOKEN, dtype=np.uint8)
+    padded[row_padding:row_padding + n_rows, col_padding:col_padding + n_cols] = grid + COLOR_OFFSET_TOKEN
+
+    eos_row, eos_col = row_padding + n_rows, col_padding + n_cols
+    if eos_row < ARC_MAX_GRID_SIZE:
+        padded[eos_row, col_padding:eos_col] = END_OF_SEQUENCE_TOKEN
+    if eos_col < ARC_MAX_GRID_SIZE:
+        padded[row_padding:eos_row, eos_col] = END_OF_SEQUENCE_TOKEN
 
-    # Convert
-    dests = set(dest_mapping.values())
-    converted = {dest: ARCPuzzle(name, []) for dest in dests}
-    for example_type, examples in puzzle.items():
-        dest = dest_mapping[example_type]
-        converted[dest].examples.extend([(arc_grid_to_np(example["input"]), arc_grid_to_np(example["output"])) for example in examples])
-
-    group = [converted]
+    return padded
+
+
+def _apply_translational_augment(example: PuzzleExample) -> PuzzleExample:
+    max_rows = max(example.input_grid.shape[0], example.output_grid.shape[0])
+    max_cols = max(example.input_grid.shape[1], example.output_grid.shape[1])
+    row_padding = np.random.randint(0, ARC_MAX_GRID_SIZE - max_rows + 1)
+    col_padding = np.random.randint(0, ARC_MAX_GRID_SIZE - max_cols + 1)
+
+    return PuzzleExample(
+        example_type=example.example_type,
+        input_grid=_pad_grid(example.input_grid, row_padding, col_padding),
+        output_grid=_pad_grid(example.output_grid, row_padding, col_padding),
+    )
+
+
+def _apply_dihedral_transform_augment(puzzle: ARCPuzzle) -> ARCPuzzle:
+    trans_id = np.random.randint(0, DIHEDRAL_SYMMETRIES)
+    color_map = np.concatenate(
+        [
+            np.arange(BLACK_COLOR, BLACK_COLOR + 1, dtype=np.uint8),
+            np.random.permutation(np.arange(BLACK_COLOR + 1, ARC_MAX_COLOR_VALUE + 1, dtype=np.uint8)),
+        ]
+    )
+    aug_repr = f"t{trans_id}_{hashlib.blake2b(color_map.tobytes(), digest_size=8).hexdigest()}"
+
+    augmented_examples = [
+        PuzzleExample(
+            example_type=example.example_type,
+            input_grid=color_map[dihedral_transform(example.input_grid, trans_id)],
+            output_grid=color_map[dihedral_transform(example.output_grid, trans_id)],
+        )
+        for example in puzzle.examples
+    ]
+
+    return ARCPuzzle(id=f"{puzzle.id}_{aug_repr}", split=puzzle.split, examples=augmented_examples)
+
+
+def _generate_puzzle_augmentations(puzzle: ARCPuzzle, aug_count: int, augment_retries_factor: int) -> list[ARCPuzzle]:
+    group = [puzzle]
+
+    seen_puzzles = {puzzle.hash}
+    for _ in range(augment_retries_factor * aug_count):
+        if (augmented := _apply_dihedral_transform_augment(puzzle)).hash not in seen_puzzles:
+            seen_puzzles.add(augmented.hash)
+            group.append(augmented)
+        if len(group) >= aug_count + 1:
+            break
+
+    no_translation_idx = np.random.randint(0, len(puzzle.examples))
+    return [
+        ARCPuzzle(
+            id=puzzle.id,
+            split=puzzle.split,
+            examples=[
+                _apply_translational_augment(example) if idx != no_translation_idx and puzzle.split == "training" else example.pad()
+                for idx, example in enumerate(puzzle.examples)
+            ],
+        )
+        for puzzle in group
+    ]
+
+
+def _load_split_raw_puzzles(arc_split: RawArcSplit, dirpath: Path) -> list[ARCPuzzle]:
+    puzzles = [
+        ARCPuzzle.from_raw_file(
+            filepath=Path(filename),
+            arc_split=arc_split,
+            puzzle_id=Path(filename).stem,
+        )
+        for filename in Path(dirpath).glob("*.json")
+    ]
+
+    random.shuffle(puzzles)
+    return puzzles
+
+
+def _load_all_raw_puzzles(dataset_dir: str | Path) -> list[ARCPuzzle]:
+    puzzles = []
+
+    subdirs = (d for d in Path(dataset_dir).iterdir() if d.is_dir())
+    for split_dir in subdirs:
+        arc_split: RawArcSplit = RawArcSplitAdapter.validate_python(split_dir.name)
+        puzzles.extend(_load_split_raw_puzzles(arc_split, split_dir))
+    return puzzles
+
+
+def _process_single_puzzle(puzzle: ARCPuzzle, config: ARCDatasetBuildConfig) -> tuple[RawArcSplit, list[ARCPuzzle]]:
+    augmented_puzzles = _generate_puzzle_augmentations(
+        puzzle, config.num_aug, config.augment_retries_factor
+    )
+    return puzzle.split, augmented_puzzles
+
+
+def _split_puzzle_augmentations(
+    original_puzzle_split: RawArcSplit,
+    puzzle_augmentations: list[ARCPuzzle]
+) -> tuple[list[ARCPuzzle], list[ARCPuzzle]]:
+    if original_puzzle_split == "training":
+        return puzzle_augmentations, []
+
+    def create_split(example_type: ARCExampleType) -> list[ARCPuzzle]:
+        return [
+            p.model_copy(update={"examples": examples})
+            for p in puzzle_augmentations
+            if (examples := [ex for ex in p.examples if ex.example_type == example_type])
+        ]
+
+    return create_split("train"), create_split("test")
+
+
+def _process_arcagi_dataset(
+    dataset_paths: list[str | Path], config: ARCDatasetBuildConfig
+) -> dict[ProcessedArcSplit, list[list[ARCPuzzle]]]:
+    puzzles = []
+    for dataset_path in dataset_paths:
+        if Path(dataset_path).exists():
+            dataset_puzzles = _load_all_raw_puzzles(dataset_path)
+            puzzles.extend(dataset_puzzles)
+            print(f"[{dataset_path}] loaded {len(dataset_puzzles)} puzzles")
 
-    # Augment
-    if aug_count > 0:
-        hashes = {puzzle_hash(converted)}
-
-        for _trial in range(ARCAugmentRetriesFactor * aug_count):
-            # Augment plan
-            trans_id = np.random.randint(0, 8)
-            mapping = np.concatenate([np.arange(0, 1, dtype=np.uint8), np.random.permutation(np.arange(1, 10, dtype=np.uint8))])  # Permute colors, Excluding "0" (black)
-
-            aug_repr = f"t{trans_id}_{''.join(str(x) for x in mapping)}"
-
-            def _map_grid(grid: np.ndarray):
-                return dihedral_transform(mapping[grid], trans_id)
-
-            # Check duplicate
-            augmented = {dest: ARCPuzzle(f"{puzzle.id}_{aug_repr}", [(_map_grid(input), _map_grid(label)) for (input, label) in puzzle.examples]) for dest, puzzle in converted.items()}
-            h = puzzle_hash(augmented)
-            if h not in hashes:
-                hashes.add(h)
-                group.append(augmented)
-
-                if len(group) >= aug_count + 1:
-                    break
-
-        if len(group) < aug_count + 1:
-            print (f"[Puzzle {name}] augmentation not full, only {len(group)}")
-
-    # Append
-    for dest in dests:
-        # Convert the examples
-        dest_split, dest_set = dest
-
-        results.setdefault(dest_split, {})
-        results[dest_split].setdefault(dest_set, [])
-        results[dest_split][dest_set].append([converted[dest] for converted in group])
-
-
-def load_puzzles_arcagi(results: dict, dataset_path: str, config: DataProcessConfig):
-    train_examples_dest = ("train", "all")
-    test_examples_map = {
-        "evaluation": [(1.0, ("test", "all"))],
-        "_default": [(1.0, ("train", "all"))]
-    }
+    train_groups: list[list[ARCPuzzle]] = []
+    test_groups: list[list[ARCPuzzle]] = []
+    process_puzzle = partial(_process_single_puzzle, config=config)
+
+    with ProcessPoolExecutor() as executor:
+        futures = [executor.submit(process_puzzle, puzzle) for puzzle in puzzles]
+        progress = tqdm.tqdm(as_completed(futures), total=len(puzzles), desc="Processing ARC puzzles")
+        for future in progress:
+            original_split, augmented_puzzles = future.result()
+            train_batch, test_batch = _split_puzzle_augmentations(original_split, augmented_puzzles)
+            train_groups.append(train_batch)
+            test_groups.append(test_batch)
+
+    return {"train": train_groups, "test": test_groups}
+
+
+def _build_identifier_map(data: dict[ProcessedArcSplit, list[list[ARCPuzzle]]]) -> dict[str, int]:
+    num_identifiers = 1  # First identifier is reserved for blank puzzle
+    identifier_map = {}
+    for puzzle in itertools.chain.from_iterable(itertools.chain.from_iterable(data.values())):
+        if puzzle.id in identifier_map:
+            continue
+        identifier_map[puzzle.id] = num_identifiers
+        num_identifiers += 1
+    return identifier_map
+
+
+def _save_split_data(split: ProcessedArcSplit, puzzles: list[list[ARCPuzzle]], identifier_map: dict[str, int], output_dir: Path, metadata_filename: str) -> None:
+    all_puzzles = list(itertools.chain.from_iterable(puzzles))
+    all_examples = list(itertools.chain.from_iterable(p.examples for p in all_puzzles))
 
-    total_puzzles = 0
-    for subdir in os.scandir(dataset_path):
-        if subdir.is_dir():
-            # Load all puzzles in this directory
-            puzzles = []
-            for filename in glob(os.path.join(subdir.path, "*.json")):
-                with open(filename, "r") as f:
-                    puzzles.append((Path(filename).stem, json.load(f)))
-
-            # Shuffle puzzles
-            np.random.shuffle(puzzles)
-
-            # Assign by fraction
-            for idx, (default_name, puzzle) in enumerate(puzzles):
-                fraction = idx / len(puzzles)
-                test_examples_dest = None
-                for f, dest in test_examples_map.get(subdir.name, test_examples_map["_default"]):
-                    if fraction < f:
-                        test_examples_dest = dest
-                        break
-
-                assert test_examples_dest is not None
-
-                convert_single_arc_puzzle(results, default_name, puzzle, config.num_aug, {"train": train_examples_dest, "test": test_examples_dest})
-                total_puzzles += 1
-
-    print (f"[{dataset_path}] total puzzles: {total_puzzles}")
-
-
-def convert_dataset(config: DataProcessConfig):
+    flattened_inputs = np.stack([ex.input_grid.flatten() for ex in all_examples], 0)
+    flattened_labels = np.stack([ex.output_grid.flatten() for ex in all_examples], 0)
+    puzzle_identifiers = np.array([identifier_map[p.id] for p in all_puzzles], dtype=np.uint32)
+    puzzle_indices = np.cumsum([0] + [len(p.examples) for p in all_puzzles], dtype=np.uint32)
+    group_indices = np.cumsum([0] + [len(aug_puzzles) for aug_puzzles in puzzles], dtype=np.uint32)
+
+    split_output_dir = output_dir / split
+    split_output_dir.mkdir(parents=True, exist_ok=True)
+    for name, array in (
+        ("inputs", flattened_inputs),
+        ("labels", flattened_labels),
+        ("puzzle_identifiers", puzzle_identifiers),
+        ("puzzle_indices", puzzle_indices),
+        ("group_indices", group_indices),
+    ):
+        np.save(split_output_dir / f"{SET_NAME}__{name}.npy", array)
+
+    total_puzzles = len(all_puzzles)
+    total_examples = puzzle_indices[-1] if total_puzzles > 0 else 0
+    metadata = PuzzleDatasetMetadata(
+        seq_len=ARC_MAX_GRID_SIZE * ARC_MAX_GRID_SIZE,
+        vocab_size=(ARC_MAX_COLOR_VALUE + 1) + N_PADDING_TOKENS + N_EOS_TOKENS,
+        pad_id=PAD_TOKEN,
+        ignore_label_id=PAD_TOKEN,
+        blank_identifier_id=PAD_TOKEN,
+        num_puzzle_identifiers=len(identifier_map) + 1,  # +1 for blank puzzle
+        total_groups=total_puzzles,
+        mean_puzzle_examples=(total_examples / total_puzzles if total_puzzles > 0 else 0),
+        sets=[SET_NAME],
+    )
+    with open(split_output_dir / f"{metadata_filename}.json", "w") as f:
+        json.dump(metadata.model_dump(), f)
+
+    print(f"Saved {split} data to {split_output_dir}")
+
+
+def _save_identifier_list(identifier_map: dict[str, int], output_path: Path) -> None:
+    reverse_map = {v: k for k, v in identifier_map.items()}
+    identifiers = [reverse_map.get(i, BLANK_PUZZLE_ID) for i in range(len(identifier_map) + 1)]
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        json.dump(identifiers, f)
+
+
+def convert_dataset(config: ARCDatasetBuildConfig) -> None:
     np.random.seed(config.seed)
-
-    # Read dataset
-    data = {}
-    for dataset_dir in config.dataset_dirs:
-        load_puzzles_arcagi(data, dataset_dir, config)
-
-    # Map global puzzle identifiers
-    num_identifiers = 1  # 0 is blank
-    identifier_map = {}
-    for split_name, split in data.items():
-        for subset_name, subset in split.items():
-            for group in subset:
-                for puzzle in group:
-                    if puzzle.id not in identifier_map:
-                        identifier_map[puzzle.id] = num_identifiers
-                        num_identifiers += 1
-
-    print (f"Total puzzle IDs (including <blank>): {num_identifiers}")
-
-    # Save
-    for split_name, split in data.items():
-        os.makedirs(os.path.join(config.output_dir, split_name), exist_ok=True)
-
-        # Translational augmentations
-        enable_translational_augment = split_name == "train"
+    random.seed(config.seed)
+    output_dir = Path(config.processed_dataset_dir)
 
-        # Statistics
-        total_examples = 0
-        total_puzzles = 0
-        total_groups = 0
-
-        for subset_name, subset in split.items():
-            # Construct subset
-            results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
-            results["puzzle_indices"].append(0)
-            results["group_indices"].append(0)
-
-            example_id = 0
-            puzzle_id = 0
-
-            for group in subset:
-                for puzzle in group:
-                    # Push puzzle
-                    no_aug_id = np.random.randint(0, len(puzzle.examples))
-                    for _idx_ex, (inp, out) in enumerate(puzzle.examples):
-                        inp, out = np_grid_to_seq_translational_augment(inp, out, do_translation=enable_translational_augment and _idx_ex != no_aug_id)
-
-                        results["inputs"].append(inp)
-                        results["labels"].append(out)
-                        example_id += 1
-
-                        total_examples += 1
-
-                    results["puzzle_indices"].append(example_id)
-                    results["puzzle_identifiers"].append(identifier_map[puzzle.id])
-
-                    puzzle_id += 1
-
-                    total_puzzles += 1
-
-                # Push group
-                results["group_indices"].append(puzzle_id)
-                total_groups += 1
-
-            for k, v in results.items():
-                if k in {"inputs", "labels"}:
-                    v = np.stack(v, 0)
-                else:
-                    v = np.array(v, dtype=np.int32)
-
-                np.save(os.path.join(config.output_dir, split_name, f"{subset_name}__{k}.npy"), v)
-
-        # Metadata
-        metadata = PuzzleDatasetMetadata(
-            seq_len=ARCMaxGridSize * ARCMaxGridSize,
-            vocab_size=10 + 2,  # PAD + EOS + "0" ... "9"
-
-            pad_id=0,
-            ignore_label_id=0,
-
-            blank_identifier_id=0,
-            num_puzzle_identifiers=num_identifiers,
-
-            total_groups=total_groups,
-            mean_puzzle_examples=total_examples / total_puzzles,
-            sets=list(split.keys())
-        )
+    per_split_puzzle_groups = _process_arcagi_dataset(config.raw_dataset_dirs, config)
+    print("Building identifier map")
+    identifier_map = _build_identifier_map(per_split_puzzle_groups)
+    print("Saving identifier list")
+    _save_identifier_list(identifier_map, output_dir / f"{config.identifiers_filename}.json")
+    for split, puzzles in per_split_puzzle_groups.items():
+        _save_split_data(split, puzzles, identifier_map, output_dir, config.metadata_filename)
-
-        # Save metadata as JSON.
-        with open(os.path.join(config.output_dir, split_name, "dataset.json"), "w") as f:
-            json.dump(metadata.model_dump(), f)
-
-    # Save IDs mapping
-    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
-        ids_mapping = {v: k for k, v in identifier_map.items()}
-
-        json.dump([ids_mapping.get(i, "<blank>") for i in range(num_identifiers)], f)
 
 
+if __name__ == "__main__":
+    from argdantic import ArgParser
 
-@cli.command(singleton=True)
-def main(config: DataProcessConfig):
-    convert_dataset(config)
+    cli = ArgParser()
 
+    @cli.command(singleton=True)
+    def main(config: ARCDatasetBuildConfig):
+        convert_dataset(ARCDatasetBuildConfig(**config.model_dump()))
 
-if __name__ == "__main__":
     cli()