diff --git a/tests/test_eval_cli.py b/tests/test_eval_cli.py
index ccb87d4c..91a096c4 100644
--- a/tests/test_eval_cli.py
+++ b/tests/test_eval_cli.py
@@ -67,6 +67,8 @@ def __init__(self, api_key=None, base_url=None):
         },
         verbose=False,
         save_dataset=False,
+        save_dataset_cache=False,
+        cache_dir="~/.cache/verifiers/",
         save_to_hf_hub=False,
         hf_hub_dataset_name="",
     )
@@ -114,6 +116,8 @@ def __init__(self, api_key=None, base_url=None):
         },
         verbose=False,
         save_dataset=False,
+        save_dataset_cache=False,
+        cache_dir="~/.cache/verifiers/",
         save_to_hf_hub=False,
         hf_hub_dataset_name="",
     )
diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py
index 7c92db11..31628ea4 100644
--- a/verifiers/scripts/eval.py
+++ b/verifiers/scripts/eval.py
@@ -5,6 +5,7 @@
 import logging
 import time
 import uuid
+import os
 from datetime import datetime
 from pathlib import Path
 
@@ -35,6 +36,8 @@ def eval_environment(
     sampling_args: dict | None,
     verbose: bool,
     save_dataset: bool,
+    save_dataset_cache: bool,
+    cache_dir: str,
     save_to_hf_hub: bool,
     hf_hub_dataset_name: str,
 ):
@@ -145,7 +148,7 @@ def eval_environment(
                 out = f"r{i + 1}: {trials}"
                 print(out)
 
-    if save_dataset or save_to_hf_hub:
+    if save_dataset or save_dataset_cache or save_to_hf_hub:
         ids = [i // rollouts_per_example for i in range(n * rollouts_per_example)]
         rewards = results.reward
         tasks = results.task
@@ -198,6 +201,57 @@ def eval_environment(
                 json.dump(metadata, f)
 
             logger.info(f"Saved dataset to {results_path}")
+        if save_dataset_cache:
+            # Resolve the cache root: a non-empty cache_dir wins, then the
+            # VF_CACHE_DIR environment variable, then ~/.cache/verifiers.
+            # expanduser() is required because Path does not expand "~" on
+            # its own; without it a literal "~" directory is created in the
+            # current working directory.
+            if cache_dir:
+                cache_base = Path(cache_dir).expanduser()
+            else:
+                cache_base = Path(
+                    os.environ.get(
+                        "VF_CACHE_DIR",
+                        Path.home() / ".cache" / "verifiers",
+                    )
+                ).expanduser()
+            cache_path = cache_base / "evals" / env_model_str / uuid_str
+            cache_path.mkdir(parents=True, exist_ok=True)
+
+            dataset.to_json(cache_path / "results.jsonl")
+            with open(cache_path / "metadata.json", "w") as f:
+                json.dump(metadata, f)
+
+            # index.json maps "<env_model_str>/<uuid_str>" to a summary of
+            # each cached run (env, model, timestamp, path, reward, sizes).
+            index_file = cache_base / "evals" / "index.json"
+            index_data = {}
+            if index_file.exists():
+                try:
+                    with open(index_file, "r") as f:
+                        index_data = json.load(f)
+                except json.JSONDecodeError:
+                    # A truncated or corrupt index must not fail the run after
+                    # the results were already written; start a fresh index.
+                    logger.warning(f"Ignoring unreadable cache index {index_file}")
+                    index_data = {}
+
+            run_key = f"{env_model_str}/{uuid_str}"
+            index_data[run_key] = {
+                "env": env,
+                "model": model,
+                "timestamp": datetime.now().isoformat(),
+                "path": str(cache_path),
+                "avg_reward": metadata.get("avg_reward"),
+                "num_examples": n,
+                "rollouts_per_example": rollouts_per_example,
+            }
+
+            with open(index_file, "w") as f:
+                json.dump(index_data, f, indent=2)
+            logger.info(f"Saved dataset to cache at {cache_path}")
+
         if save_to_hf_hub:
             if hf_hub_dataset_name == "":
                 dataset_name = (
@@ -307,6 +361,20 @@ def main():
         action="store_true",
         help="Save dataset to disk",
     )
+    parser.add_argument(
+        "--save-dataset-cache",
+        "-sc",
+        default=False,
+        action="store_true",
+        help="Save dataset to disk .cache",
+    )
+    parser.add_argument(
+        "--cache-dir",
+        "-C",
+        type=str,
+        default="",
+        help="Custom cache directory",
+    )
     parser.add_argument(
         "--save-to-hf-hub",
         "-H",
@@ -339,6 +407,8 @@ def main():
         sampling_args=args.sampling_args,
         verbose=args.verbose,
         save_dataset=args.save_dataset,
+        save_dataset_cache=args.save_dataset_cache,
+        cache_dir=args.cache_dir,
         save_to_hf_hub=args.save_to_hf_hub,
         hf_hub_dataset_name=args.hf_hub_dataset_name,
     )