diff --git a/conf/base.yaml b/conf/base.yaml index 6be2b593..c29092f5 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -15,6 +15,7 @@ actor: llm_max_rollouts: 64 rollout_workers: 1 discount_factor: 1 + pause_training_during_eval: true problem_queue_size: 64 result_queue_size: 64 throughput_window_size: 50 @@ -140,4 +141,3 @@ wandb: wandb_dir: null # Comma-separated list of keywords to tag the run. tags: [] - diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 358b3797..7a152a4b 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -7,11 +7,11 @@ import queue import random import time -from collections import defaultdict +from collections import Counter, defaultdict from multiprocessing.managers import SharedMemoryManager from pathlib import Path from queue import Empty -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional import aiohttp import hydra @@ -21,7 +21,6 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM - import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.rollouts import BaseMetrics, RolloutResult @@ -51,6 +50,22 @@ def save_debug_line(data:dict): with open(fname, "a") as f: f.write(json.dumps(data, ensure_ascii=False) + "\n") +def get_number_of_tokens_in_result(result: RolloutResult) -> int: + """Aggregate prompt + output tokens for all training texts in a rollout result.""" + total_tokens = 0 + for training_text in result.training_texts: + prompt_tokens = getattr(training_text, "prompt_tokens", 0) or 0 + output_tokens = getattr(training_text, "output_tokens", 0) or 0 + if prompt_tokens or output_tokens: + total_tokens += prompt_tokens + output_tokens + continue + input_ids = getattr(training_text, "input_ids", None) + if input_ids: + total_tokens += len(input_ids) + continue + total_tokens += getattr(training_text, "n_predicted", 0) or 0 + return total_tokens + class SlidingWindowData(BaseModel): prompt_tokens_window: list[list[int]] = Field( default_factory=list, @@ -119,8 +134,124 @@ def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) -def get_number_of_tokens_in_result(result: RolloutResult) -> int: - return sum(training_text.prompt_tokens + training_text.output_tokens for training_text in result.training_texts) +def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: + raw_schedule = curriculum_cfg.get("schedule", []) + if not raw_schedule: + return [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}] + if not isinstance(raw_schedule, list): + raw_schedule = [raw_schedule] + parsed_schedule: list[dict] = [] + for entry in raw_schedule: + step = int(entry.get("step", 0)) + if step < 0: + step = 0 + hard_weight = float(entry.get("hard_weight", 0.0)) + hard_weight = max(0.0, min(1.0, hard_weight)) + medium_weight_value = entry.get("medium_weight", 0.0) + try: + medium_weight = float(medium_weight_value) + except (TypeError, ValueError): + medium_weight = 0.0 + medium_weight = max(0.0, min(1.0, medium_weight)) + weight_sum = medium_weight + hard_weight + if weight_sum > 1.0 and weight_sum > 0.0: + scale = 1.0 / weight_sum + medium_weight *= scale + hard_weight *= scale + ready_success_cfg = entry.get("ready_success") or [] + if not isinstance(ready_success_cfg, list): + ready_success_cfg = [ready_success_cfg] + ready_success: list[dict] = [] + for cond in ready_success_cfg: + if not isinstance(cond, dict): + continue + dataset = cond.get("dataset") + metric = 
cond.get("metric", "success_mean") + try: + threshold = float(cond.get("threshold", 1.0)) + except (TypeError, ValueError): + threshold = 1.0 + ready_success.append( + { + "dataset": dataset, + "metric": metric, + "threshold": threshold, + } + ) + patience_value = entry.get("ready_patience", 1) + try: + ready_patience = max(1, int(patience_value)) + except (TypeError, ValueError): + ready_patience = 1 + parsed_schedule.append( + { + "step": step, + "medium_weight": medium_weight, + "hard_weight": hard_weight, + "ready_success": ready_success, + "ready_patience": ready_patience, + } + ) + parsed_schedule.sort(key=lambda item: item["step"]) + return parsed_schedule + + +def advance_curriculum_stage( + schedule: list[dict], + stage_state: dict, + samples_processed: int, + stats: dict, + logger: logging.Logger | None = None, +) -> None: + if not schedule or stage_state is None: + return + current_idx = int(stage_state.get("index", 0)) + ready_counts = stage_state.setdefault("ready_counts", {}) + advanced = False + + while current_idx + 1 < len(schedule): + next_idx = current_idx + 1 + stage_cfg = schedule[next_idx] + min_step = stage_cfg.get("step", 0) + if samples_processed < min_step: + break + + ready_conditions: list[dict] = stage_cfg.get("ready_success") or [] + patience = max(1, int(stage_cfg.get("ready_patience", 1))) + + if ready_conditions: + all_pass = True + for cond in ready_conditions: + dataset = cond.get("dataset") + metric = cond.get("metric", "success_mean") + threshold = float(cond.get("threshold", 1.0)) + if dataset: + metric_key = f"{dataset}/{metric}" + else: + metric_key = metric + value = stats.get(metric_key) + if value is None or value < threshold: + all_pass = False + break + if all_pass: + ready_counts[next_idx] = ready_counts.get(next_idx, 0) + 1 + else: + ready_counts[next_idx] = 0 + break + if ready_counts[next_idx] < patience: + break + current_idx = next_idx + stage_state["index"] = current_idx + ready_counts.pop(next_idx, None) + advanced = True + if logger: + logger.info( + "Curriculum stage %d activated (samples_processed=%d)", + current_idx, + samples_processed, + ) + if not advanced: + stage_state["index"] = current_idx async def schedule_rollouts( @@ -166,9 +297,9 @@ async def rollout_and_maybe_produce_result( llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None - logger.info(f"Starting rollout policy for problem {problem['id']}") + logger.debug(f"Starting rollout policy for problem {problem['id']}") rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session) - logger.info(f"Finished rollout policy for problem {problem['id']}") + logger.debug(f"Finished rollout policy for problem {problem['id']}") rollout_result.model_version = model_version token_count += get_number_of_tokens_in_result(rollout_result) # Make a group id that will be different from groups made by another rollout maker @@ -215,7 +346,7 @@ async def rollout_and_maybe_produce_result( if finished_rollouts > old_finished_rollouts: old_finished_rollouts = finished_rollouts save_debug_line({"rollouts_finished": finished_rollouts, "tokens_produced": token_count, "dt": time.time() - start_time, "token_speed": token_count / (time.time() - start_time)}) - logger.info( + logger.debug( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " f"groups in progress: {len(group_rollouts)}, " @@ -286,6 +417,113 @@ def random_iter(problems: list): yield random.sample(problems, 1)[0] +def curriculum_iter( + problems: 
list, + trainer_state: TrainerState, + curriculum_cfg: DictConfig, + logger: logging.Logger | None = None, + parsed_schedule: Optional[list[dict]] = None, + stage_state: Optional[dict] = None, +): + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) + base_names = set(curriculum_obj.get("base_datasets", [])) + medium_names = set(curriculum_obj.get("medium_datasets", [])) + hard_names = set(curriculum_obj.get("hard_datasets", [])) + if hard_names and not base_names: + base_names = { + problem.get("dataset") + for problem in problems + if problem.get("dataset") not in hard_names and problem.get("dataset") not in medium_names + } + + base_pool = [ + problem + for problem in problems + if (problem.get("dataset") in base_names) + or ( + not base_names + and problem.get("dataset") not in hard_names + and (not medium_names or problem.get("dataset") not in medium_names) + ) + ] + medium_pool = [problem for problem in problems if problem.get("dataset") in medium_names] + hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] + + if not base_pool and medium_pool: + base_pool = list(medium_pool) + medium_pool = [] + + if not base_pool: + if logger: + logger.warning("Curriculum enabled but no matching datasets were found; falling back to random sampling") + yield from random_iter(problems) + return + + schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) + if not schedule: + schedule = [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}] + + current_stage_index = stage_state.get("index", 0) if stage_state is not None else 0 + last_logged_stage: Optional[int] = None + + while True: + samples_processed = trainer_state.samples_processed or 0 + max_stage_allowed = current_stage_index + while ( + max_stage_allowed + 1 < len(schedule) + and samples_processed >= schedule[max_stage_allowed + 1]["step"] + ): + max_stage_allowed += 1 + + if stage_state is not None: + desired_index = int(stage_state.get("index", 0)) + current_stage_index = max(0, min(desired_index, max_stage_allowed)) + else: + current_stage_index = max_stage_allowed + + stage_cfg = schedule[current_stage_index] + medium_weight = stage_cfg.get("medium_weight", 0.0) + hard_weight = stage_cfg.get("hard_weight", 0.0) + if not medium_pool: + medium_weight = 0.0 + if not hard_pool: + hard_weight = 0.0 + base_weight = max(0.0, 1.0 - medium_weight - hard_weight) + weight_sum = base_weight + medium_weight + hard_weight + if weight_sum == 0.0: + base_weight = 1.0 + medium_weight = 0.0 + hard_weight = 0.0 + + if logger and last_logged_stage != current_stage_index: + logger.info( + "Curriculum stage %d active (samples_processed=%d, base=%.3f, medium=%.3f, hard=%.3f)", + current_stage_index, + samples_processed, + base_weight, + medium_weight, + hard_weight, + ) + last_logged_stage = current_stage_index + + if stage_state is not None: + stage_state["index"] = current_stage_index + + choice = random.random() + hard_cutoff = hard_weight + medium_cutoff = hard_cutoff + medium_weight + if hard_pool and choice < hard_cutoff: + yield random.choice(hard_pool) + elif medium_pool and choice < medium_cutoff: + yield random.choice(medium_pool) + else: + yield random.choice(base_pool) + + def sequential_iter(problems: list): for problem in problems: yield problem @@ -304,14 +542,19 @@ def __init__( self.data_stream = data_stream self.trainer_state = trainer_state self.stats_stream = stats_stream - self.sliding_aggregator = 
SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size) + self.sliding_aggregator = None + if is_training: + self.sliding_aggregator = SlidingWindowAggregator( + window_size=cfg.actor.throughput_window_size + ) self.llms = llms self.loop_start_time = -1 self.cfg: DictConfig = cfg self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) - self.cfg: DictConfig = cfg + self.curriculum_schedule: list[dict] | None = None + self.curriculum_stage_state: dict | None = None self.smm: SharedMemoryManager | None = None self.problem_queue: SharedMemoryQueue | None = None @@ -359,6 +602,8 @@ def init_stats(self): self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) + self.answer_status_counts = Counter() + self.dataset_sample_counts = Counter() def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} @@ -390,13 +635,28 @@ def update_stats(self, rollout_results: List[RolloutResult]): else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") - prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts] - output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts] - self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) - sliding_window_stats = self.sliding_aggregator.get_stats() - if sliding_window_stats is not None: - for k, v in sliding_window_stats.items(): - self.sliding_stats[k].append(v) + status = getattr(result, "answer_status", None) + if status in {"correct", "wrong", "unparsable", "no_answer"}: + self.answer_status_counts[status] += 1 + if dataset_name: + self.dataset_sample_counts[str(dataset_name)] += 1 + + if self.sliding_aggregator: + prompt_length_tokens = [ + training_text.prompt_tokens + for result in rollout_results + for training_text in result.training_texts + ] + output_length_tokens = [ + training_text.output_tokens + for result in rollout_results + for training_text in result.training_texts + ] + self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) + sliding_window_stats = self.sliding_aggregator.get_stats() + if sliding_window_stats is not None: + for k, v in sliding_window_stats.items(): + self.sliding_stats[k].append(v) @@ -417,9 +677,35 @@ def run(self, dataset: list[tuple[str, dict]]): # for train sample, sample random batches infinitely # for test samples, loop through the dataset once if self.is_training: - problem_iter = random_iter(dataset) + curriculum_cfg = getattr(self.cfg.actor, "curriculum", None) + use_curriculum = bool( + curriculum_cfg and getattr(curriculum_cfg, "enabled", False) and dataset + ) + if use_curriculum: + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) + parsed_schedule = parse_curriculum_schedule(curriculum_obj) + self.curriculum_schedule = parsed_schedule + self.curriculum_stage_state = {"index": 0} + problem_iter = curriculum_iter( + dataset, + trainer_state=self.trainer_state, + curriculum_cfg=curriculum_cfg, + logger=logger, + parsed_schedule=parsed_schedule, + stage_state=self.curriculum_stage_state, + ) + else: + problem_iter = random_iter(dataset) + self.curriculum_schedule = None + self.curriculum_stage_state = None else: problem_iter = sequential_iter(dataset) + self.curriculum_schedule = None + 
self.curriculum_stage_state = None assert self.trainer_state.propagated_weight_version is not None last_trainer_version = self.trainer_state.propagated_weight_version @@ -585,22 +871,27 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 - - rename_suffixes = { - "num_python_calls_mean": "python_calls_mean", - "used_python_mean": "python_usage_rate", - "num_math_answer_calls_mean": "math_answer_calls_mean", - "used_math_answer_mean": "math_answer_usage_rate", - } - - for key in list(stats.keys()): - for old_suffix, new_suffix in rename_suffixes.items(): - if key.endswith(old_suffix): - prefix = key[: -len(old_suffix)] - stats[f"{prefix}{new_suffix}"] = stats[key] - break - - logger.info(f"Publish actor stats to wandb: {stats}") + if self.curriculum_schedule and self.curriculum_stage_state is not None: + advance_curriculum_stage( + self.curriculum_schedule, + self.curriculum_stage_state, + self.trainer_state.samples_processed or 0, + stats, + logger, + ) + stats["curriculum_stage_active"] = self.curriculum_stage_state.get("index", 0) + + total_status = sum(self.answer_status_counts.values()) + if total_status: + for status, count in self.answer_status_counts.items(): + stats[f"{split_name}answer_status_{status}_count"] = count + stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status + total_rollouts = sum(self.dataset_sample_counts.values()) + if total_rollouts: + stats["dataset_rollouts_total"] = total_rollouts + for dataset, count in self.dataset_sample_counts.items(): + stats[f"{dataset}/rollout_count"] = count + stats[f"{dataset}/rollout_ratio"] = count / total_rollouts if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) @@ -869,6 +1160,9 @@ def run_actor_loop(cfg: DictConfig): ) test_loop_run = None + pause_training_during_eval = bool( + getattr(cfg.actor, "pause_training_during_eval", True) + ) last_regular_eval = -1 current_eval = -1 while True: @@ -892,7 +1186,8 @@ def run_actor_loop(cfg: DictConfig): test_loop_run = test_loop.run( dataset=test_dataset, ) - train_loop.is_scheduling_paused = True + if pause_training_during_eval: + train_loop.is_scheduling_paused = True current_eval = next_regular_eval # 2. If there is an active test loop, keep it running @@ -903,7 +1198,8 @@ def run_actor_loop(cfg: DictConfig): # 2.1 If the test loop is finished, resume scheduling the training loop test_loop_run = None last_regular_eval = current_eval - train_loop.is_scheduling_paused = False + if pause_training_during_eval: + train_loop.is_scheduling_paused = False logger.info("Test loop finished") # 3. 
Keep running the training loop diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 2eec4a9b..fbfaea97 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -2,7 +2,8 @@ import logging import random import re -from typing import Dict, List, Tuple +from pathlib import Path +from typing import Dict, Iterable, List, Sequence, Tuple import datasets import hydra @@ -190,26 +191,6 @@ def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]: return add_ids(samples) -def _load_aime_2025_opencompass(upsample_factor: int = 0) -> list[dict]: - configs = ["AIME2025-I", "AIME2025-II"] - dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original") - - samples: list[dict] = [] - for config_name in configs: - ds = load_dataset("opencompass/AIME2025", config_name, split="test") - samples.extend([s for s in process_math(ds, dataset_name) if s is not None]) - - original_size = len(samples) - if upsample_factor > 0: - samples *= upsample_factor - - logger.info( - f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples" - + (f" (upsampled from {original_size})" if upsample_factor > 0 else "") - ) - return add_ids(samples) - - def _load_amc_dataset(year: int, upsample_factor: int = 0) -> list[dict]: amc_dataset = load_dataset("AI-MO/aimo-validation-amc", split="train", trust_remote_code=True) amc_dataset = amc_dataset.filter(lambda x: str(year) in x["url"]) @@ -234,18 +215,93 @@ def add_ids(dataset: list[dict]): return dataset +def _resolve_custom_path(relative_paths: str | Sequence[str]) -> Path: + """ + Resolve a path for locally generated datasets. + + Hydra jobs may change the working directory, so we check both the current + directory and the repository root. + """ + if isinstance(relative_paths, str): + relative_paths = [relative_paths] + + resolved = Path(__file__).resolve() + base_candidates = [Path.cwd()] + if len(resolved.parents) >= 5: + base_candidates.append(resolved.parents[4]) + + candidates: List[Path] = [] + for rel in relative_paths: + rel_path = Path(rel) + candidates.append(rel_path) + for base in base_candidates: + if base == Path.cwd(): + continue + candidates.append(base / rel_path) + + for candidate in candidates: + if candidate.exists(): + return candidate + raise FileNotFoundError( + f"Custom dataset not found. Tried: {[str(path) for path in candidates]}" + ) + + +def _load_custom_dataset(dataset_name: str) -> list[dict]: + """ + Load a locally generated dataset by name. + + The loader searches under `datasets/custom/` and `datasets/custom_runs/` for either + `` or `.jsonl`. 
+ """ + candidate_names: List[str] = [] + if dataset_name.endswith(".jsonl"): + candidate_names.append(dataset_name) + else: + candidate_names.extend([dataset_name, f"{dataset_name}.jsonl"]) + + search_paths: List[str] = [] + for name in candidate_names: + search_paths.extend( + [ + f"datasets/custom/{name}", + f"datasets/custom_runs/{name}", + name, + ] + ) + + dataset_path = _resolve_custom_path(search_paths) + with dataset_path.open("r", encoding="utf-8") as handle: + samples = [json.loads(line) for line in handle if line.strip()] + + dataset_label = dataset_name[:-6] if dataset_name.endswith(".jsonl") else dataset_name + + for idx, sample in enumerate(samples): + sample.setdefault("source_dataset", sample.get("dataset", dataset_label)) + sample.setdefault("source_id", sample.get("id")) + sample["dataset"] = dataset_label + sample["id"] = idx + + logger.info(f"Loading custom dataset {dataset_name}: {len(samples)} samples from {dataset_path}") + return samples + + def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None) -> List[Tuple[str, Dict]]: if dataset_names is None: return [] if isinstance(dataset_names, str): dataset_names = [dataset_names] + # Preserve order while de-duplicating + dataset_names = list(dict.fromkeys(dataset_names)) datasets = [] + remaining = set(dataset_names) if "eurus_train" in dataset_names: dataset = load_dataset("PRIME-RL/Eurus-2-RL-Data", split="train", trust_remote_code=True) samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("eurus_train") # great for debugging since its much smaller than eurus train if "eurus_validation" in dataset_names: @@ -253,6 +309,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_eurus(dataset) if s is not None] logger.info(f"Loading eurus validation dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("eurus_validation") if "math_train" in dataset_names: # math_dataset = load_math("train") @@ -260,6 +317,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_train") if s is not None] logger.info(f"Loading math train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_train") if "math_simplerl_train" in dataset_names: # SimpleRL MATH dataset @@ -274,6 +332,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_simplerl_train") if s is not None] logger.info(f"Loading math simplerl train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_simplerl_train") if "simplerl_math_subset_1000" in dataset_names: # SimpleRL MATH dataset subset @@ -292,12 +351,14 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = samples[:1000] logger.info(f"Loading math simplerl subset test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("simplerl_math_subset_1000") if "deepscaler_preview" in dataset_names: dataset = load_dataset("agentica-org/DeepScaleR-Preview-Dataset", split="train", trust_remote_code=True) samples = [s for s in process_math(dataset, "deepscaler") if s is not None] logger.info(f"Loading deepscaler preview train dataset: {len(samples)} samples") datasets += add_ids(samples) + 
remaining.discard("deepscaler_preview") if "math_test" in dataset_names: # math_dataset = load_math("test") @@ -305,36 +366,42 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_math(dataset, "math_test") if s is not None] logger.info(f"Loading math test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_test") if "omni_math_500" in dataset_names: dataset = load_dataset("reliable-agents/Omni-MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "omni_math_500") if s is not None] logger.info(f"Loading omni math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("omni_math_500") if "math_500" in dataset_names: dataset = load_dataset("HuggingFaceH4/MATH-500", split="test", trust_remote_code=True) samples = [s for s in process_math(dataset, "math_500") if s is not None] logger.info(f"Loading math 500 dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("math_500") if "open_r1_math_220k" in dataset_names: dataset = load_dataset("open-r1/OpenR1-Math-220k", split="default", trust_remote_code=True) samples = [s for s in process_math(dataset, "open_r1_math_220k") if s is not None] logger.info(f"Loading open r1 math 220k dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_r1_math_220k") if "gpqa_main" in dataset_names: dataset = load_dataset("hendrydong/gpqa_main", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_main") if s is not None] logger.info(f"Loading gpqa main dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gpqa_main") if "gpqa_diamond" in dataset_names: dataset = load_dataset("hendrydong/gpqa_diamond", split="test", trust_remote_code=True) samples = [s for s in process_gpqa(dataset, "gpqa_diamond") if s is not None] logger.info(f"Loading gpqa diamond dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gpqa_diamond") if "gpqa_diamond" in dataset_names: pass @@ -344,55 +411,70 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_gsm8k(dataset, "gsm8k_train") if s is not None] logger.info(f"Loading gsm8k train dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gsm8k_train") if "gsm8k_test" in dataset_names: dataset = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True) samples = [s for s in process_gsm8k(dataset, "gsm8k_test") if s is not None] logger.info(f"Loading gsm8k test dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("gsm8k_test") if "limo" in dataset_names: dataset = load_dataset("GAIR/LIMO", split="train", trust_remote_code=True) samples = [s for s in process_limo(dataset) if s is not None] logger.info(f"Loading limo dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("limo") if "aime_2022" in dataset_names: datasets += _load_aime_dataset(2022, upsample_factor=16) + remaining.discard("aime_2022") if "aime_2022_original" in dataset_names: datasets += _load_aime_dataset(2022) + remaining.discard("aime_2022_original") if "aime_2023" in dataset_names: datasets += _load_aime_dataset(2023, upsample_factor=16) + remaining.discard("aime_2023") if "aime_2023_original" in dataset_names: datasets += _load_aime_dataset(2023) + remaining.discard("aime_2023_original") if 
"aime_2024" in dataset_names: datasets += _load_aime_dataset(2024, upsample_factor=16) + remaining.discard("aime_2024") if "aime_2024_original" in dataset_names: datasets += _load_aime_dataset(2024) + remaining.discard("aime_2024_original") if "aime_2025" in dataset_names: datasets += _load_aime_2025_opencompass_dataset(upsample_factor=16) + remaining.discard("aime_2025") if "aime_2025_original" in dataset_names: datasets += _load_aime_2025_opencompass_dataset() + remaining.discard("aime_2025_original") if "amc_2022" in dataset_names: # TODO: AMC 2022 is 43 problems, is that to be expected? datasets += _load_amc_dataset(2022, upsample_factor=16) + remaining.discard("amc_2022") if "amc_2022_original" in dataset_names: datasets += _load_amc_dataset(2022) + remaining.discard("amc_2022_original") if "amc_2023" in dataset_names: datasets += _load_amc_dataset(2023, upsample_factor=16) + remaining.discard("amc_2023") if "amc_2023_original" in dataset_names: datasets += _load_amc_dataset(2023) + remaining.discard("amc_2023_original") if "sometimes_success_data" in dataset_names: PATH = "data/sometimes_success_data/data.jsonl" @@ -400,6 +482,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [json.loads(line) for line in f] logger.info(f"Loading easy data dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("sometimes_success_data") if "open_reasoner_zero_57k" in dataset_names: dataset = load_dataset( @@ -411,6 +494,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_57k") if s is not None] logger.info(f"Loading Open Reasoner Zero dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_57k") if "open_reasoner_zero_extended_72k" in dataset_names: dataset = load_dataset( @@ -422,6 +506,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_extended_72k") if s is not None] logger.info(f"Loading Open Reasoner Zero extended dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_extended_72k") if "open_reasoner_zero_hard_13k" in dataset_names: dataset = load_dataset( @@ -433,6 +518,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_open_reasoner(dataset, "open_reasoner_zero_hard_13k") if s is not None] logger.info(f"Loading Open Reasoner Zero hard dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard("open_reasoner_zero_hard_13k") for dataset_name in dataset_names: test_matched = re.match(r"multiplication_(\d+)_by_(\d+)_(\d+)_test", dataset_name) @@ -453,6 +539,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None ] logger.info(f"Loading multiplication {num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples") datasets += add_ids(samples) + remaining.discard(dataset_name) elif train_matched: upto_prefix = train_matched.group(1) or "" num_digits_1 = int(train_matched.group(2)) @@ -474,6 +561,7 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None f"Loading multiplication {upto_prefix}_{num_digits_1}_by_{num_digits_2} dataset: {len(samples)} samples" ) datasets += add_ids(samples) + remaining.discard(dataset_name) if "countdown" in dataset_names: dataset = load_dataset( 
@@ -482,6 +570,19 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None samples = [s for s in process_countdown(dataset) if s is not None] logger.info(f"Loading countdown dataset: {len(samples)} samples") datasets += samples + remaining.discard("countdown") + + # resolve any remaining names as local custom datasets. + unresolved: List[str] = [] + for dataset_name in list(remaining): + try: + datasets += _load_custom_dataset(dataset_name) + remaining.discard(dataset_name) + except FileNotFoundError: + unresolved.append(dataset_name) + + if unresolved: + raise ValueError(f"Unknown dataset(s): {unresolved}") if len(datasets) == 0: raise ValueError("No datasets loaded") diff --git a/pipelinerl/domains/math/minimal_rollout.py b/pipelinerl/domains/math/minimal_rollout.py deleted file mode 100644 index bc46d2a1..00000000 --- a/pipelinerl/domains/math/minimal_rollout.py +++ /dev/null @@ -1,72 +0,0 @@ -import time -import random - -import aiohttp -from omegaconf import DictConfig -from pydantic import BaseModel -from pipelinerl.rollouts import RolloutResult, BaseMetrics -from pipelinerl.world import Job -from tapeagents.core import Prompt -from tapeagents.llms.trainable import TrainableLLM - -from pipelinerl.async_llm import llm_async_generate, make_training_text -from .verifier_api import verify_answer_rpc - -class Metrics(BaseMetrics): - pass - -class RewardTable(BaseModel): - wrong_answer_not_finished: float - wrong_answer_finished: float - no_answer_not_finished: float - no_answer_finished: float - unparsable_not_finished: float - unparsable_finished: float - correct_answer_not_finished: float - correct_answer_finished: float - buffer_tokens: int = 0 # 0 means no overlong reward shaping - -def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float: - """ - Compute the overlong penalty - """ - if sequence_length > (max_length - buffer_tokens) and sequence_length <= max_length: - return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens - return 0. 
- -def get_reward(trace, answer_status: str, rewards: RewardTable) -> float: - pass - - -async def generate_math_rollout( - cfg: DictConfig, - llm: TrainableLLM, - problem: dict, - session: aiohttp.ClientSession, -) -> RolloutResult: - messages = [] - if cfg.actor.system_prompt: - messages.append({"role": "system", "content": cfg.actor.system_prompt}) - messages.append({"role": "user", "content": f"{problem['task']} \n{cfg.actor.task_prompt}"}) - prompt = Prompt(messages=messages) - - time_start = time.time() - llm_call = await llm_async_generate(llm, prompt, session) - latency = time.time() - time_start - - assert llm_call.output.content is not None - rewards = RewardTable(**dict(cfg.rewards)) - - env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] - env_job = random.choice(env_jobs) - assert env_job.port is not None - answer_status = await verify_answer_rpc(session=session, host=env_job.hostname, port=env_job.port, prediction=llm_call.output.content, gold=problem["answer"]) - - trace = make_training_text(llm, llm_call) - reward = get_reward(trace, answer_status, rewards) - trace.reward = reward - - metrics = Metrics(reward=reward, success=answer_status == "correct", no_error=answer_status != "unparsable", no_answer=answer_status == "no_answer") - - - return RolloutResult(training_texts=[trace], metrics=metrics, latency=latency, dataset_name=problem.get("dataset")) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 4da0b753..09c879a6 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -1,7 +1,10 @@ +import asyncio +import json +import logging +import random import re import time -import random -import logging +from pathlib import Path import aiohttp from omegaconf import DictConfig @@ -14,6 +17,7 @@ from pipelinerl.async_llm import llm_async_generate, make_training_text from .verifier_api import verify_answer_rpc +logger = logging.getLogger(__name__) class Metrics(BaseMetrics): penalty: float @@ -139,6 +143,30 @@ def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens return 0. + +def log_answer_status(cfg: DictConfig, problem: dict, answer_status: str, reward: float, latency: float) -> None: + """ + Metric logging for answer status - correct, wrong, no_answer, unparsable + """ + try: + log_dir = Path(cfg.output_dir) if cfg.output_dir else None + if not log_dir: + return + log_path = log_dir / "answer_status.jsonl" + record = { + "t": time.time(), + "problem_id": problem.get("id"), + "dataset": problem.get("dataset"), + "answer_status": answer_status, + "reward": reward, + "latency": latency, + } + with log_path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record)) + handle.write("\n") + except Exception: + logger.debug("Failed to append answer status log", exc_info=True) + async def generate_math_rollout( cfg: DictConfig, llm: TrainableLLM, @@ -155,7 +183,16 @@ async def generate_math_rollout( prompt = Prompt(messages=messages) time_start = time.time() - llm_call = await llm_async_generate(llm, prompt, session) + try: + llm_call = await llm_async_generate(llm, prompt, session) + except (asyncio.TimeoutError, aiohttp.client_exceptions.ServerTimeoutError) as exc: + latency = time.time() - time_start + logger.warning( + "LLM request timed out for problem %s. 
Skipping sample.", + problem.get("id"), + exc_info=exc, + ) + return create_timeout_rollout_result(cfg, problem, latency) latency = time.time() - time_start assert llm_call.output.content is not None @@ -192,9 +229,14 @@ async def generate_math_rollout( reward *= discount_factor**llm_call.output_length_tokens overlong_penalty = 0 if reward_table.buffer_tokens > 0: - overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, rewards.buffer_tokens) + overlong_penalty = length_penalty( + llm.parameters["max_tokens"], + llm_call.output_length_tokens, + reward_table.buffer_tokens, + ) reward += overlong_penalty trace.reward = reward + log_answer_status(cfg, problem, answer_status, reward, latency) # Prefer backend-provided finish reason if available; normalize for comparisons if isinstance(trace.metadata, dict): @@ -245,6 +287,32 @@ async def generate_math_rollout( return RolloutResult( training_texts=[trace], metrics=metrics, - latency=latency, + latency=latency, + dataset_name=problem.get("dataset"), + answer_status=answer_status, + ) + + +def create_timeout_rollout_result( + cfg: DictConfig, + problem: dict, + latency: float, +) -> RolloutResult: + answer_status = "timeout" + metrics = Metrics( + reward=0.0, + success=False, + no_error=False, + no_answer=True, + penalty=0.0, + overflow=False, + auto_boxed=False, + ) + log_answer_status(cfg, problem, answer_status, metrics.reward, latency) + return RolloutResult( + training_texts=[], + metrics=metrics, + latency=latency, dataset_name=problem.get("dataset"), + answer_status=answer_status, ) diff --git a/pipelinerl/domains/math/verifier_api.py b/pipelinerl/domains/math/verifier_api.py index 9fa9b4ef..84548f15 100644 --- a/pipelinerl/domains/math/verifier_api.py +++ b/pipelinerl/domains/math/verifier_api.py @@ -1,28 +1,25 @@ -import time -import requests import asyncio -from concurrent.futures import ProcessPoolExecutor -import aiohttp -import uvicorn import logging +import re import signal +import time +from concurrent.futures import ProcessPoolExecutor from contextlib import contextmanager +from functools import partial -from omegaconf import DictConfig -import math_verify # Ensure math_verify is installed - +import aiohttp +import math_verify +import requests # noqa: F401 - retained for parity with upstream +import uvicorn from fastapi import FastAPI from fastapi.responses import JSONResponse -from functools import partial import pipelinerl.countdown_utils logging.basicConfig( - level=logging.DEBUG, # Or INFO, WARNING, etc. + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - logging.StreamHandler(), # Logs to console - ], + handlers=[logging.StreamHandler()], ) @@ -46,93 +43,257 @@ class EmptyBoxedException(Exception): @contextmanager -def timeout(seconds=1): +def timeout(seconds: int = 1): def timeout_handler(signum, frame): raise TimeoutException("Computation timed out") - # Set the timeout handler original_handler = signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(seconds) try: - yield # This is the key addition - context managers must yield + yield finally: - # Restore the original handler and disable the alarm signal.alarm(0) signal.signal(signal.SIGALRM, original_handler) -def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: - """ - Checks if a predicted answer matches a gold (correct) answer by making a request to the math_verify package. 
- - Args: - prediction (str): The predicted answer to validate - gold (str): The gold (correct) answer to compare against - strict (bool): Whether to enforce strict comparison mode. - - In strict mode: Variables matter and sets are not comparable with tuples - - In non-strict mode: Variables are matched by position and sets can be compared with tuples - url (str): URL of the validation service endpoint - - Returns: - str: The status of the answer, which can be one of the following: - - "correct": The prediction is correct - - "wrong": The prediction is incorrect - - "no_answer": The prediction is empty - - "unparsable": The prediction cannot be parsed - - """ - if prediction.startswith("countdown"): - return verify_countdown(prediction, gold) - else: - return verify_math(prediction, gold, strict=strict, max_prediction_length=max_prediction_length) +ANSWER_PREFIX_RE = re.compile( + r"^(final answer|answer|ans\.?|thus.*?is|therefore.*?is|so the answer is)[:=\-\s]*", + re.IGNORECASE, +) -def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: - import re +def _strip_answer_prefix(line: str) -> str: + return ANSWER_PREFIX_RE.sub("", line).strip() + + +def _extract_fallback_expression(text: str) -> str | None: + if not text: + return None + for raw_line in reversed(text.strip().splitlines()): + cleaned = _strip_answer_prefix(raw_line.strip()).rstrip(".;!") + if cleaned and (any(ch.isdigit() for ch in cleaned) or "\\" in cleaned): + return cleaned + return None + + +def remove_boxed(s: str) -> str: + if "\\boxed " in s: + left = "\\boxed " + if not s.startswith(left): + raise UnparsableException() + return s[len(left) :] + + left = "\\boxed{" + if not s.startswith(left) or not s.endswith("}"): + raise UnparsableException() + return s[len(left) : -1] + + +def last_boxed_only_string(text: str) -> str | None: + idx = text.rfind("\\boxed") + if "\\boxed " in text: + return "\\boxed " + text.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = text.rfind("\\fbox") + if idx < 0: + return None + + right_brace_idx = None + num_left_braces_open = 0 + i = idx + while i < len(text): + if text[i] == "{": + num_left_braces_open += 1 + if text[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + if right_brace_idx is None: + return None + return text[idx : right_brace_idx + 1] + + +def fix_fracs(string: str) -> str: + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + if len(substr) < 2: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + return new_str + + +def fix_a_slash_b(string: str) -> str: + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + return f"\\frac{{{a}}}{{{b}}}" + except (AssertionError, ValueError): + return string + + +def remove_right_units(string: str) -> str: + if "\\text{ " in string: + splits = string.split("\\text{ ") + if len(splits) == 2: + return splits[0] + return string + + +def fix_sqrt(string: 
str) -> str: + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split and split[0] != "{": + a = split[0] + new_substr = f"\\sqrt{{{a}}}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + +def strip_string(string: str) -> str: + string = string.replace("\n", "").replace("\\!", "").replace("\\\\", "\\") + string = string.replace("tfrac", "frac").replace("dfrac", "frac") + string = string.replace("\\left", "").replace("\\right", "") + string = string.replace("^{\\circ}", "").replace("^\\circ", "") + string = string.replace("\\$", "") + string = remove_right_units(string) + string = string.replace("\\%", "") + string = string.replace(" .", " 0.").replace("{.", "{0.") + if not string: + return string + if string[0] == ".": + string = "0" + string + if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + string = fix_sqrt(string) + string = string.replace(" ", "") + string = fix_fracs(string) + if string == "0.5": + string = "\\frac{1}{2}" + string = fix_a_slash_b(string) + return string + + +def is_equiv(str1: str, str2: str) -> bool: + if str1 is None and str2 is None: + return True + if str1 is None or str2 is None: + return False + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 + +def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: try: - # Input Sanitization / Validation if not isinstance(prediction, str) or not isinstance(gold, str): raise ValueError("Prediction and gold must be strings") - # Try extracting from \boxed{...} first - boxed_start = prediction.rfind("\\boxed{") + prediction = prediction.strip() + if not prediction: + raise NoAnswerException() - if boxed_start >= 0: - boxed_prediction = prediction[boxed_start:] - if "\\boxed{}" in boxed_prediction: - raise EmptyBoxedException() - if len(boxed_prediction) > max_prediction_length: - raise UnparsableException() - extracted_prediction = boxed_prediction - else: - # Fallback: look for ... 
tags + extracted_prediction: str | None = None + + boxed_prediction = last_boxed_only_string(prediction) + if boxed_prediction is not None: + try: + extracted_prediction = remove_boxed(boxed_prediction).strip() + except UnparsableException as exc: + logger.debug("Failed to remove boxed expression", exc_info=exc) + extracted_prediction = None + + if not extracted_prediction: answer_match = re.findall(r"(.*?)", prediction, re.DOTALL) if answer_match: - extracted_prediction = answer_match[-1].strip() # last one if multiple + extracted_prediction = answer_match[-1].strip() else: - raise NoAnswerException() + fallback_expression = _extract_fallback_expression(prediction) + if fallback_expression: + extracted_prediction = fallback_expression.strip() + else: + raise NoAnswerException() + + if not extracted_prediction: + raise EmptyBoxedException() + + if 0 < max_prediction_length < len(extracted_prediction): + raise UnparsableException() + + if is_equiv(gold, extracted_prediction): + return "correct" + + try: + target_boxed = last_boxed_only_string(f"\\boxed{{{gold}}}") or f"\\boxed{{{gold}}}" + pred_boxed = last_boxed_only_string(f"\\boxed{{{extracted_prediction}}}") or f"\\boxed{{{extracted_prediction}}}" + gold_parsed = math_verify.parse(target_boxed) + pred_parsed = math_verify.parse(pred_boxed) + except Exception as parse_exc: + logger.debug("math_verify.parse failed", exc_info=parse_exc) + raise UnparsableException() from parse_exc - # Parse and verify - gold_parsed = math_verify.parse(gold) - pred_parsed = math_verify.parse(extracted_prediction) if not pred_parsed: - raise ValueError("Failed to parse prediction.") + raise UnparsableException("Prediction parsed to empty result.") - with timeout(1): - equivalent = math_verify.verify(gold_parsed, pred_parsed, strict=strict, timeout_seconds=1) + try: + with timeout(1): + equivalent = math_verify.verify(gold_parsed, pred_parsed, strict=strict, timeout_seconds=1) + except TimeoutException as timeout_exc: + logger.debug("math_verify.verify timed out; treating as wrong", exc_info=timeout_exc) + return "wrong" + except (ValueError, TypeError, NotImplementedError) as verify_exc: + logger.debug("math_verify.verify raised recoverable error; treating as wrong", exc_info=verify_exc) + return "wrong" + except Exception as verify_exc: + logger.debug("math_verify.verify failed unexpectedly", exc_info=verify_exc) + raise - answer_status = "correct" if equivalent else "wrong" + return "correct" if equivalent else "wrong" - except Exception as e: - match e: + except Exception as error: + match error: case NoAnswerException(): answer_status = "no_answer" + case (EmptyBoxedException() | UnparsableException()): + answer_status = "unparsable" case _: + logger.debug("Unexpected verifier error", exc_info=error) answer_status = "unparsable" - - return answer_status - + return answer_status def verify_countdown(prediction: str, gold: str) -> str: @@ -144,28 +305,35 @@ def verify_countdown(prediction: str, gold: str) -> str: if equation is None: return "no_answer" - format_correct = pipelinerl.countdown_utils.validate_format(prediction) - if not format_correct: + if not pipelinerl.countdown_utils.validate_format(prediction): return "unparsable" - # Validate equation uses correct numbers if not pipelinerl.countdown_utils.validate_equation(equation, numbers): return "wrong" - # Evaluate equation try: result = pipelinerl.countdown_utils.evaluate_equation(equation) if result is None: return "wrong" - - if abs(result - target) < 1e-5: # Account for floating point precision 
- return "correct" - else: - return "wrong" - except Exception as _: + return "correct" if abs(result - target) < 1e-5 else "wrong" + except Exception: return "wrong" +def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str: + try: + if prediction.startswith("countdown"): + return verify_countdown(prediction, gold) + return verify_math(prediction, gold, strict=strict, max_prediction_length=max_prediction_length) + except NoAnswerException: + return "no_answer" + except UnparsableException: + return "unparsable" + except Exception as exc: + logger.debug("verify_answer unexpected failure", exc_info=exc) + return "unparsable" + + async def verify_answer_rpc( session: aiohttp.ClientSession, host: str, @@ -175,36 +343,24 @@ async def verify_answer_rpc( strict: bool = True, max_prediction_length: int = 1000, ): - """ - Verify the answer using the verifier API. - """ - json = { + payload = { "prediction": prediction, "gold": gold, "strict": strict, "max_prediction_length": max_prediction_length, } - async with session.post( - f"http://{host}:{port}/verify_answer", - json=json, - ) as response: + async with session.post(f"http://{host}:{port}/verify_answer", json=payload) as response: if response.status == 200: data = await response.json() return data["answer_status"] - else: - logger.error(f"Error verifying answer: {response.status}") - logger.error(f"Response: {await response.text()}") - raise ValueError("Error verifying answer") + logger.error("Error verifying answer: %s", response.status) + logger.error("Response: %s", await response.text()) + raise ValueError("Error verifying answer") class MathEnvironment: - def launch(self, port: int): - """ - Serve the verification API using FastAPI. - """ app = FastAPI() - # Create a process pool with 4 workers with ProcessPoolExecutor(max_workers=4) as process_pool: @app.post("/verify_answer") async def verify(request: dict): @@ -213,7 +369,6 @@ async def verify(request: dict): strict = request["strict"] max_prediction_length = request["max_prediction_length"] - # Run verification in the process pool to avoid blocking the main thread loop = asyncio.get_event_loop() answer_status = await loop.run_in_executor( process_pool, partial(verify_answer, prediction, gold, strict, max_prediction_length) @@ -225,5 +380,3 @@ async def health(): return JSONResponse(content={"status": "ok"}) uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=60) - - diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 9a1e4773..8a57c567 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -22,6 +22,7 @@ from litellm import BaseModel, Field from pipelinerl.finetune.logging_ import flatten_dict_config +from pipelinerl.preprocess_helpers import group_rollout_idx, validate_rollout_group from pipelinerl.shared_memory_array import SharedMemoryArray, SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.utils import setup_logging, wait_for_inference_servers, init_wandb @@ -196,10 +197,27 @@ def run_dataset_loader( buffer = [] n_groups = 0 for group in reader.read(): + if not group: + continue + is_complete, missing, extra = validate_rollout_group(group, check_group_size) + if not is_complete: + group_name = group[0].get("group_id") if group else "" + if not missing and not extra: + logger.warning("Skipping group %s without rollout metadata", group_name) + else: + logger.warning( + "Skipping incomplete group %s: missing rollouts %s extra %s", + group_name, + missing, + 
extra, + ) + continue buffer.extend(group) n_groups += 1 if n_groups == chunk_n_groups: break + if not buffer: + continue if not _check_group_sizes(buffer, check_group_size): raise ValueError("Invalid group sizes in data") try: diff --git a/pipelinerl/preprocess_helpers.py b/pipelinerl/preprocess_helpers.py new file mode 100644 index 00000000..88ee81a2 --- /dev/null +++ b/pipelinerl/preprocess_helpers.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from typing import Iterable + + +def group_rollout_idx(group: Iterable[dict]) -> set[int] | None: + """Extract rollout idx from a rollout group.""" + rollout_indices: set[int] = set() + for text in group: + metadata = text.get("metadata") + if not isinstance(metadata, dict): + return None + rollout_index = metadata.get("rollout_index") + if rollout_index is None: + return None + rollout_indices.add(rollout_index) + return rollout_indices + + +def validate_rollout_group(group: Iterable[dict], group_size: int) -> tuple[bool, list[int], list[int]]: + """Return whether a group is complete and any missing or extra rollout indices.""" + rollout_indices = group_rollout_idx(group) + if rollout_indices is None: + return False, [], [] + if len(rollout_indices) != group_size: + expected_indices = set(range(group_size)) + if rollout_indices.issubset(expected_indices): + missing = sorted(expected_indices - rollout_indices) + extra: list[int] = [] + else: + missing = sorted(expected_indices - rollout_indices) + extra = sorted(rollout_indices - expected_indices) + return False, missing, extra + return True, [], [] diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py index dcb27f2d..6ee2b6c2 100644 --- a/pipelinerl/rollouts.py +++ b/pipelinerl/rollouts.py @@ -64,3 +64,4 @@ class RolloutResult(BaseModel): model_version: int | None = None dataset_name: str | None = None group_id: str | None = None + answer_status: str | None = None