From e584b51717e9294e4eb84ed41740cd4e0561759d Mon Sep 17 00:00:00 2001 From: rafapi Date: Thu, 16 Oct 2025 18:30:22 +0000 Subject: [PATCH 01/12] Simple curriculum --- pipelinerl/actor.py | 90 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 9ffabe7e..bd6badce 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -14,7 +14,7 @@ import aiohttp import hydra import uvloop -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM from typing import Dict, List @@ -260,6 +260,80 @@ def random_iter(problems: list): yield random.sample(problems, 1)[0] +def curriculum_iter( + problems: list, + trainer_state: TrainerState, + curriculum_cfg: DictConfig, + logger: logging.Logger | None = None, +): + curriculum_obj = OmegaConf.to_container(curriculum_cfg, resolve=True) if isinstance(curriculum_cfg, DictConfig) else curriculum_cfg + base_names = set(curriculum_obj.get("base_datasets", [])) + hard_names = set(curriculum_obj.get("hard_datasets", [])) + if hard_names and not base_names: + base_names = {problem.get("dataset") for problem in problems if problem.get("dataset") not in hard_names} + + base_pool = [ + problem + for problem in problems + if (problem.get("dataset") in base_names) or (not base_names and problem.get("dataset") not in hard_names) + ] + hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] + + if not hard_pool: + if logger: + logger.warning( + "Curriculum enabled but no problems matched hard_datasets list; falling back to base sampling" + ) + yield from random_iter(problems) + return + + if not base_pool: + if logger: + logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset") + base_pool = hard_pool + + schedule_cfg = curriculum_obj.get("schedule", []) + if not schedule_cfg: + schedule = [(0, 0.0)] + else: + if not isinstance(schedule_cfg, list): + schedule_cfg = [schedule_cfg] + schedule = [] + for entry in schedule_cfg: + step = int(entry.get("step", 0)) + hard_weight = float(entry.get("hard_weight", 0.0)) + hard_weight = max(0.0, min(1.0, hard_weight)) + schedule.append((step, hard_weight)) + schedule.sort(key=lambda item: item[0]) + + current_stage = -1 + + while True: + samples_processed = trainer_state.samples_processed or 0 + hard_weight = schedule[0][1] + stage_index = 0 + for idx, (step, weight) in enumerate(schedule): + if samples_processed >= step: + hard_weight = weight + stage_index = idx + else: + break + + if logger and stage_index != current_stage: + logger.info( + "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)", + stage_index, + samples_processed, + hard_weight, + ) + current_stage = stage_index + + if hard_pool and random.random() < hard_weight: + yield random.choice(hard_pool) + else: + yield random.choice(base_pool) + + def sequential_iter(problems: list): for problem in problems: yield problem @@ -384,7 +458,19 @@ def run(self, dataset: list[tuple[str, dict]]): # for train sample, sample random batches infinitely # for test samples, loop through the dataset once if self.is_training: - problem_iter = random_iter(dataset) + curriculum_cfg = getattr(self.cfg.actor, "curriculum", None) + use_curriculum = bool( + curriculum_cfg and getattr(curriculum_cfg, "enabled", False) and dataset + ) + if use_curriculum: + problem_iter = curriculum_iter( + dataset, + 
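(For reference, the `actor.curriculum` node consumed by this wiring would look roughly like the sketch below; a minimal sketch with illustrative dataset names and weights, using only the keys curriculum_iter reads: enabled, base_datasets, hard_datasets, and schedule entries with step and hard_weight.)

    from omegaconf import OmegaConf

    # Illustrative values only; dataset names are placeholders, not from the repo.
    curriculum_cfg = OmegaConf.create({
        "enabled": True,
        "base_datasets": ["math_easy"],
        "hard_datasets": ["math_hard"],
        # Once samples_processed >= step, hard problems are sampled
        # with probability hard_weight (clamped to [0, 1]).
        "schedule": [
            {"step": 0, "hard_weight": 0.0},
            {"step": 50_000, "hard_weight": 0.25},
            {"step": 150_000, "hard_weight": 0.5},
        ],
    })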
trainer_state=self.trainer_state, + curriculum_cfg=curriculum_cfg, + logger=logger, + ) + else: + problem_iter = random_iter(dataset) else: problem_iter = sequential_iter(dataset) assert self.trainer_state.propagated_weight_version is not None From efa0d02d1a5e67974dbaa7c283d175eece502f66 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Oct 2025 17:32:51 +0000 Subject: [PATCH 02/12] Adaptive curriculum schedule --- pipelinerl/actor.py | 216 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 195 insertions(+), 21 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index bd6badce..39ce5bb7 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -7,7 +7,7 @@ from queue import Empty import random import time -from collections import defaultdict +from collections import defaultdict, deque from multiprocessing.managers import SharedMemoryManager from pathlib import Path @@ -17,7 +17,7 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM -from typing import Dict, List +from typing import Dict, List, Optional import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb @@ -103,10 +103,100 @@ def get_stats(self): +class CurriculumSuccessTracker: + def __init__(self) -> None: + self._buffers: dict[str, deque[int]] = {} + self._max_windows: dict[str, int] = {} + self._total_counts: defaultdict[str, int] = defaultdict(int) + + def ensure_window(self, dataset: str, window: int) -> None: + if window <= 0: + window = 1 + current = self._max_windows.get(dataset, 0) + if window <= current: + return + existing = self._buffers.get(dataset, deque(maxlen=window)) + if existing.maxlen != window: + new_buffer = deque(existing, maxlen=window) + else: + new_buffer = existing + self._buffers[dataset] = new_buffer + self._max_windows[dataset] = window + + def update(self, dataset: str, success_values: list[int | bool]) -> None: + if not success_values: + return + buffer = self._buffers.get(dataset) + if buffer is None: + maxlen = self._max_windows.get(dataset, max(1, len(success_values))) + buffer = deque(maxlen=maxlen) + self._buffers[dataset] = buffer + self._max_windows[dataset] = maxlen + for value in success_values: + buffer.append(1 if bool(value) else 0) + self._total_counts[dataset] += 1 + + def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[float]: + buffer = self._buffers.get(dataset) + if buffer is None or not buffer: + return None + if window is None or window <= 0 or window >= len(buffer): + values = list(buffer) + else: + if len(buffer) < window: + return None + values = list(buffer)[-window:] + if not values: + return None + return sum(values) / len(values) + + def total_samples(self, dataset: str) -> int: + return self._total_counts.get(dataset, 0) + + + def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) +def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: + raw_schedule = curriculum_cfg.get("schedule", []) + if not raw_schedule: + return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] + if not isinstance(raw_schedule, list): + raw_schedule = [raw_schedule] + parsed_schedule: list[dict] = [] + for entry in raw_schedule: + step = int(entry.get("step", 0)) + hard_weight = float(entry.get("hard_weight", 0.0)) + hard_weight = max(0.0, min(1.0, hard_weight)) + thresholds_cfg = entry.get("success_thresholds", []) or [] + if not isinstance(thresholds_cfg, list): + thresholds_cfg = [thresholds_cfg] + thresholds: 
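(A quick behavioral sketch of the tracker defined above, with a hypothetical dataset name: each dataset gets a bounded deque of 0/1 outcomes, so success_mean is a rolling average over at most `window` recent rollouts. Note that in this version a window larger than the current buffer silently averages the whole buffer; PATCH 05 below tightens that case to return None until a full window is available.)

    # Assumes the CurriculumSuccessTracker class defined above.
    tracker = CurriculumSuccessTracker()
    tracker.ensure_window("math_hard", 4)          # keep at most 4 outcomes
    tracker.update("math_hard", [True, False, True, True])
    assert tracker.total_samples("math_hard") == 4
    assert tracker.success_mean("math_hard", window=4) == 0.75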
list[dict] = [] + for threshold_entry in thresholds_cfg: + dataset = threshold_entry.get("dataset") + if not dataset: + continue + threshold_value = float(threshold_entry.get("threshold", 1.0)) + window = int(threshold_entry.get("window", threshold_entry.get("window_size", 0) or 1)) + if window <= 0: + window = 1 + min_samples_value = threshold_entry.get("min_samples") + min_samples = int(min_samples_value) if min_samples_value is not None else None + thresholds.append( + { + "dataset": dataset, + "threshold": threshold_value, + "window": window, + "min_samples": min_samples, + } + ) + parsed_schedule.append({"step": step, "hard_weight": hard_weight, "thresholds": thresholds}) + parsed_schedule.sort(key=lambda item: item["step"]) + return parsed_schedule + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -265,8 +355,15 @@ def curriculum_iter( trainer_state: TrainerState, curriculum_cfg: DictConfig, logger: logging.Logger | None = None, + success_tracker: CurriculumSuccessTracker | None = None, + stage_state: Optional[dict] = None, + parsed_schedule: Optional[list[dict]] = None, ): - curriculum_obj = OmegaConf.to_container(curriculum_cfg, resolve=True) if isinstance(curriculum_cfg, DictConfig) else curriculum_cfg + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) base_names = set(curriculum_obj.get("base_datasets", [])) hard_names = set(curriculum_obj.get("hard_datasets", [])) if hard_names and not base_names: @@ -292,33 +389,82 @@ def curriculum_iter( logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset") base_pool = hard_pool - schedule_cfg = curriculum_obj.get("schedule", []) - if not schedule_cfg: - schedule = [(0, 0.0)] - else: - if not isinstance(schedule_cfg, list): - schedule_cfg = [schedule_cfg] - schedule = [] - for entry in schedule_cfg: - step = int(entry.get("step", 0)) - hard_weight = float(entry.get("hard_weight", 0.0)) - hard_weight = max(0.0, min(1.0, hard_weight)) - schedule.append((step, hard_weight)) - schedule.sort(key=lambda item: item[0]) + schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) + if success_tracker: + for stage in schedule: + for threshold in stage["thresholds"]: + success_tracker.ensure_window(threshold["dataset"], threshold["window"]) + + def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: + if not stage_cfg["thresholds"] or success_tracker is None: + return True, [] + blockers: list[str] = [] + for threshold in stage_cfg["thresholds"]: + dataset = threshold["dataset"] + threshold_value = threshold["threshold"] + window = threshold["window"] + min_samples = threshold.get("min_samples") + if min_samples is not None: + total_samples = success_tracker.total_samples(dataset) + if total_samples < min_samples: + blockers.append( + f"{dataset}: waiting for {min_samples} samples (have {total_samples})" + ) + continue + success_mean_value = success_tracker.success_mean(dataset, window) + if success_mean_value is None: + blockers.append(f"{dataset}: insufficient window data (need {window})") + continue + if success_mean_value < threshold_value: + blockers.append( + f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" + ) + return (len(blockers) == 0), blockers current_stage = -1 + last_block_log: tuple[int, tuple[str, ...]] | None = None + if stage_state is None: + stage_state = {"index": 0} while True: samples_processed = 
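(Put together, a gated schedule entry as parsed above might look like this sketch; the numbers are illustrative, and the recognized keys are dataset, threshold, window or window_size, and the optional min_samples.)

    # Stage 1 requires 50k processed samples AND a rolling success rate of at
    # least 0.7 on "math_easy" over its last 512 recorded rollouts, with at
    # least 1000 rollouts recorded before the threshold check applies.
    schedule = [
        {"step": 0, "hard_weight": 0.0},
        {
            "step": 50_000,
            "hard_weight": 0.3,
            "success_thresholds": [
                {"dataset": "math_easy", "threshold": 0.7,
                 "window": 512, "min_samples": 1000},
            ],
        },
    ]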
trainer_state.samples_processed or 0 - hard_weight = schedule[0][1] - stage_index = 0 - for idx, (step, weight) in enumerate(schedule): + desired_stage_index = 0 + hard_weight = schedule[0]["hard_weight"] + + for idx, stage_cfg in enumerate(schedule): + step = stage_cfg["step"] if samples_processed >= step: - hard_weight = weight - stage_index = idx + desired_stage_index = idx + hard_weight = stage_cfg["hard_weight"] else: break + stage_index = desired_stage_index + blocker_messages: list[str] = [] + while stage_index >= 0: + ready, blockers = stage_ready(schedule[stage_index]) + if ready: + blocker_messages = [] + break + blocker_messages = blockers + stage_index -= 1 + + if stage_index < 0: + stage_index = 0 + hard_weight = schedule[0]["hard_weight"] + else: + hard_weight = schedule[stage_index]["hard_weight"] + + if logger and desired_stage_index != stage_index and blocker_messages: + block_signature = (desired_stage_index, tuple(blocker_messages)) + if block_signature != last_block_log: + logger.info( + "Curriculum stage %d gated by: %s", + desired_stage_index, + "; ".join(blocker_messages), + ) + last_block_log = block_signature + if logger and stage_index != current_stage: logger.info( "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)", @@ -328,6 +474,8 @@ def curriculum_iter( ) current_stage = stage_index + stage_state["index"] = stage_index + if hard_pool and random.random() < hard_weight: yield random.choice(hard_pool) else: @@ -359,6 +507,8 @@ def __init__( self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) + self.curriculum_tracker: CurriculumSuccessTracker | None = None + self.curriculum_stage_state: dict | None = None # Determine the number of processes to use num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) @@ -426,8 +576,12 @@ def update_stats(self, rollout_results: List[RolloutResult]): for k, v in all_metrics.items(): if isinstance(v, list): self.stats[k][dataset_name][group_id] += v + if k == "success" and self.curriculum_tracker: + self.curriculum_tracker.update(dataset_name, v) elif isinstance(v, float) | isinstance(v, bool) | isinstance(v, int): self.stats[k][dataset_name][group_id].append(v) + if k == "success" and self.curriculum_tracker: + self.curriculum_tracker.update(dataset_name, [v]) else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") @@ -463,16 +617,34 @@ def run(self, dataset: list[tuple[str, dict]]): curriculum_cfg and getattr(curriculum_cfg, "enabled", False) and dataset ) if use_curriculum: + curriculum_obj = ( + OmegaConf.to_container(curriculum_cfg, resolve=True) + if isinstance(curriculum_cfg, DictConfig) + else curriculum_cfg + ) + parsed_schedule = parse_curriculum_schedule(curriculum_obj) + self.curriculum_tracker = CurriculumSuccessTracker() + for stage in parsed_schedule: + for threshold in stage["thresholds"]: + self.curriculum_tracker.ensure_window(threshold["dataset"], threshold["window"]) + self.curriculum_stage_state = {"index": 0} problem_iter = curriculum_iter( dataset, trainer_state=self.trainer_state, curriculum_cfg=curriculum_cfg, logger=logger, + success_tracker=self.curriculum_tracker, + stage_state=self.curriculum_stage_state, + parsed_schedule=parsed_schedule, ) else: problem_iter = random_iter(dataset) + self.curriculum_tracker = None + self.curriculum_stage_state = None else: problem_iter = sequential_iter(dataset) + self.curriculum_tracker = None + self.curriculum_stage_state = None assert 
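(The selection loop above boils down to: take the highest stage whose step has been reached, then fall back toward stage 0 until the thresholds pass. A simplified, self-contained model of that rule; it omits the logging and blocker bookkeeping of the real loop.)

    def select_stage(schedule, samples_processed, ready):
        """schedule: sorted list of (step, hard_weight); ready(i) -> bool."""
        desired = 0
        for i, (step, _weight) in enumerate(schedule):
            if samples_processed >= step:
                desired = i
            else:
                break
        stage = desired
        while stage > 0 and not ready(stage):
            stage -= 1                      # demote until a ready stage is found
        return stage, schedule[stage][1]    # stage 0 is the unconditional floor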
self.trainer_state.propagated_weight_version is not None last_trainer_version = self.trainer_state.propagated_weight_version @@ -633,6 +805,8 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 + if self.curriculum_stage_state is not None: + stats["curriculum_stage_active"] = self.curriculum_stage_state.get("index", 0) if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) From edb72d5c06bae33526c5db5036a517f9b535dff3 Mon Sep 17 00:00:00 2001 From: rafapi Date: Fri, 17 Oct 2025 19:00:22 +0000 Subject: [PATCH 03/12] Fix staging --- pipelinerl/actor.py | 59 ++++++++++++++++++++++++++++++++------------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 39ce5bb7..7d9100e1 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -429,41 +429,66 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: while True: samples_processed = trainer_state.samples_processed or 0 desired_stage_index = 0 - hard_weight = schedule[0]["hard_weight"] for idx, stage_cfg in enumerate(schedule): step = stage_cfg["step"] if samples_processed >= step: desired_stage_index = idx - hard_weight = stage_cfg["hard_weight"] else: break - stage_index = desired_stage_index - blocker_messages: list[str] = [] - while stage_index >= 0: - ready, blockers = stage_ready(schedule[stage_index]) + current_stage = int(stage_state.get("index", 0)) + if current_stage < 0: + current_stage = 0 + if current_stage >= len(schedule): + current_stage = len(schedule) - 1 + + stage_index = min(current_stage, desired_stage_index) + promotion_blockers: list[str] = [] + + # Walk backwards until the current stage is ready (or we reach stage 0) + while stage_index > 0: + ready, _ = stage_ready(schedule[stage_index]) if ready: - blocker_messages = [] break - blocker_messages = blockers stage_index -= 1 - if stage_index < 0: - stage_index = 0 - hard_weight = schedule[0]["hard_weight"] - else: - hard_weight = schedule[stage_index]["hard_weight"] + ready, current_blockers = stage_ready(schedule[stage_index]) + if not ready and stage_index > 0: + # If even after walking back we are not ready, fall back further until 0 + while stage_index > 0 and not ready: + stage_index -= 1 + ready, current_blockers = stage_ready(schedule[stage_index]) + + # Attempt to promote by at most one stage towards the desired stage + if stage_index < desired_stage_index: + next_index = stage_index + 1 + next_ready, blockers = stage_ready(schedule[next_index]) + if next_ready: + stage_index = next_index + current_blockers = [] + else: + promotion_blockers = blockers + + hard_weight = schedule[stage_index]["hard_weight"] + + blockers_for_log: list[str] = [] + block_stage: int | None = None + if stage_index < desired_stage_index: + blockers_for_log = promotion_blockers or current_blockers + block_stage = stage_index + 1 - if logger and desired_stage_index != stage_index and blocker_messages: - block_signature = (desired_stage_index, tuple(blocker_messages)) + if logger and block_stage is not None and blockers_for_log: + block_signature = (block_stage, tuple(blockers_for_log)) if block_signature != last_block_log: logger.info( "Curriculum stage %d gated by: %s", - desired_stage_index, - "; ".join(blocker_messages), + block_stage, + "; ".join(blockers_for_log), ) last_block_log = block_signature + elif stage_index >= desired_stage_index: + 
last_block_log = None if logger and stage_index != current_stage: logger.info( From 2438af8264e1822ff3eb0e877ca476df7976ea2c Mon Sep 17 00:00:00 2001 From: rafapi Date: Sat, 18 Oct 2025 18:46:17 +0000 Subject: [PATCH 04/12] Fix overlong_penalty call --- pipelinerl/domains/math/rollouts.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py index 4da0b753..862375d0 100644 --- a/pipelinerl/domains/math/rollouts.py +++ b/pipelinerl/domains/math/rollouts.py @@ -192,7 +192,11 @@ async def generate_math_rollout( reward *= discount_factor**llm_call.output_length_tokens overlong_penalty = 0 if reward_table.buffer_tokens > 0: - overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, rewards.buffer_tokens) + overlong_penalty = length_penalty( + llm.parameters["max_tokens"], + llm_call.output_length_tokens, + reward_table.buffer_tokens, + ) reward += overlong_penalty trace.reward = reward From 984c28491d11e02c31e45698e73f3f25e6362b60 Mon Sep 17 00:00:00 2001 From: rafapi Date: Sat, 18 Oct 2025 18:49:50 +0000 Subject: [PATCH 05/12] Compute num tokens in result --- pipelinerl/actor.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 93d083de..be5a5b58 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -52,6 +52,22 @@ def save_debug_line(data:dict): with open(fname, "a") as f: f.write(json.dumps(data, ensure_ascii=False) + "\n") +def get_number_of_tokens_in_result(result: RolloutResult) -> int: + """Aggregate prompt + output tokens for all training texts in a rollout result.""" + total_tokens = 0 + for training_text in result.training_texts: + prompt_tokens = getattr(training_text, "prompt_tokens", 0) or 0 + output_tokens = getattr(training_text, "output_tokens", 0) or 0 + if prompt_tokens or output_tokens: + total_tokens += prompt_tokens + output_tokens + continue + input_ids = getattr(training_text, "input_ids", None) + if input_ids: + total_tokens += len(input_ids) + continue + total_tokens += getattr(training_text, "n_predicted", 0) or 0 + return total_tokens + class SlidingWindowData(BaseModel): prompt_tokens_window: list[list[int]] = Field( default_factory=list, @@ -153,12 +169,12 @@ def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[f buffer = self._buffers.get(dataset) if buffer is None or not buffer: return None - if window is None or window <= 0 or window >= len(buffer): + if window is None or window <= 0: values = list(buffer) else: - if len(buffer) < window: - return None values = list(buffer)[-window:] + if len(values) < window: + return None if not values: return None return sum(values) / len(values) From 72baea6f82a9134b4ec186df1141b217af64e3a0 Mon Sep 17 00:00:00 2001 From: rafapi Date: Sun, 19 Oct 2025 11:28:18 +0000 Subject: [PATCH 06/12] Implement sliding stats for smooth harden, just during training --- pipelinerl/actor.py | 147 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 121 insertions(+), 26 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index be5a5b58..64fbcf2c 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -199,6 +199,19 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: step = int(entry.get("step", 0)) hard_weight = float(entry.get("hard_weight", 0.0)) hard_weight = max(0.0, min(1.0, hard_weight)) + medium_weight_value = entry.get("medium_weight", 0.0) + try: 
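(On the PATCH 04 fix above: length_penalty is defined elsewhere in the repo, so only its call shape (max_tokens, output_tokens, buffer_tokens) is known here. A hedged sketch of one plausible implementation, in the spirit of a DAPO-style soft overlong penalty; this is an assumption, not the repo's code.)

    def length_penalty(max_tokens: int, output_tokens: int, buffer_tokens: int) -> float:
        # Hypothetical: no penalty while the completion stays clear of the last
        # buffer_tokens of the budget, then a linear ramp down to -1.0 as
        # output_tokens approaches max_tokens.
        overflow = output_tokens - (max_tokens - buffer_tokens)
        if overflow <= 0:
            return 0.0
        return -min(overflow / buffer_tokens, 1.0)

Since the result is added to the reward (reward += overlong_penalty), a negative value penalizes overlong completions.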
+ medium_weight = float(medium_weight_value) + except (TypeError, ValueError): + medium_weight = 0.0 + medium_weight = max(0.0, min(1.0, medium_weight)) + demotion_patience_value = entry.get("demotion_patience", 1) + try: + demotion_patience = int(demotion_patience_value) + except (TypeError, ValueError): + demotion_patience = 1 + if demotion_patience < 1: + demotion_patience = 1 thresholds_cfg = entry.get("success_thresholds", []) or [] if not isinstance(thresholds_cfg, list): thresholds_cfg = [thresholds_cfg] @@ -221,7 +234,15 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: "min_samples": min_samples, } ) - parsed_schedule.append({"step": step, "hard_weight": hard_weight, "thresholds": thresholds}) + parsed_schedule.append( + { + "step": step, + "hard_weight": hard_weight, + "medium_weight": medium_weight, + "thresholds": thresholds, + "demotion_patience": demotion_patience, + } + ) parsed_schedule.sort(key=lambda item: item["step"]) return parsed_schedule @@ -404,15 +425,26 @@ def curriculum_iter( else curriculum_cfg ) base_names = set(curriculum_obj.get("base_datasets", [])) + medium_names = set(curriculum_obj.get("medium_datasets", [])) hard_names = set(curriculum_obj.get("hard_datasets", [])) if hard_names and not base_names: - base_names = {problem.get("dataset") for problem in problems if problem.get("dataset") not in hard_names} + base_names = { + problem.get("dataset") + for problem in problems + if problem.get("dataset") not in hard_names and problem.get("dataset") not in medium_names + } base_pool = [ problem for problem in problems - if (problem.get("dataset") in base_names) or (not base_names and problem.get("dataset") not in hard_names) + if (problem.get("dataset") in base_names) + or ( + not base_names + and problem.get("dataset") not in hard_names + and (not medium_names or problem.get("dataset") not in medium_names) + ) ] + medium_pool = [problem for problem in problems if problem.get("dataset") in medium_names] hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] if not hard_pool: @@ -423,10 +455,19 @@ def curriculum_iter( yield from random_iter(problems) return + if medium_names and not medium_pool and logger: + logger.warning( + "Curriculum medium_datasets specified but no problems matched; medium weighting will be ignored" + ) + if not base_pool: if logger: - logger.warning("Curriculum enabled but base pool is empty; sampling exclusively from hard dataset") - base_pool = hard_pool + logger.warning("Curriculum enabled but base pool is empty; falling back to medium or hard datasets") + if medium_pool: + base_pool = list(medium_pool) + medium_pool = [] + else: + base_pool = hard_pool schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj) if success_tracker: @@ -434,10 +475,11 @@ def curriculum_iter( for threshold in stage["thresholds"]: success_tracker.ensure_window(threshold["dataset"], threshold["window"]) - def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: + def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: if not stage_cfg["thresholds"] or success_tracker is None: - return True, [] + return True, [], False blockers: list[str] = [] + threshold_blocked = False for threshold in stage_cfg["thresholds"]: dataset = threshold["dataset"] threshold_value = threshold["threshold"] @@ -455,15 +497,17 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: blockers.append(f"{dataset}: insufficient window data (need {window})") continue if success_mean_value < threshold_value: + 
threshold_blocked = True blockers.append( f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" ) - return (len(blockers) == 0), blockers + return (len(blockers) == 0), blockers, threshold_blocked current_stage = -1 last_block_log: tuple[int, tuple[str, ...]] | None = None if stage_state is None: stage_state = {"index": 0} + stage_state.setdefault("consecutive_failures", {}) while True: samples_processed = trainer_state.samples_processed or 0 @@ -481,41 +525,76 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: current_stage = 0 if current_stage >= len(schedule): current_stage = len(schedule) - 1 + prev_stage = current_stage stage_index = min(current_stage, desired_stage_index) promotion_blockers: list[str] = [] # Walk backwards until the current stage is ready (or we reach stage 0) while stage_index > 0: - ready, _ = stage_ready(schedule[stage_index]) + ready, _, _ = stage_ready(schedule[stage_index]) if ready: break stage_index -= 1 - ready, current_blockers = stage_ready(schedule[stage_index]) + ready, current_blockers, _ = stage_ready(schedule[stage_index]) if not ready and stage_index > 0: # If even after walking back we are not ready, fall back further until 0 while stage_index > 0 and not ready: stage_index -= 1 - ready, current_blockers = stage_ready(schedule[stage_index]) + ready, current_blockers, _ = stage_ready(schedule[stage_index]) # Attempt to promote by at most one stage towards the desired stage if stage_index < desired_stage_index: next_index = stage_index + 1 - next_ready, blockers = stage_ready(schedule[next_index]) + next_ready, blockers, _ = stage_ready(schedule[next_index]) if next_ready: stage_index = next_index current_blockers = [] else: promotion_blockers = blockers - - hard_weight = schedule[stage_index]["hard_weight"] + promotion_block_stage: int | None = None + promotion_blockers_for_log: list[str] = [] + if stage_index < desired_stage_index: + promotion_block_stage = stage_index + 1 + promotion_blockers_for_log = promotion_blockers or current_blockers blockers_for_log: list[str] = [] block_stage: int | None = None - if stage_index < desired_stage_index: - blockers_for_log = promotion_blockers or current_blockers - block_stage = stage_index + 1 + failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {}) + demotion_cancelled = False + if prev_stage > stage_index: + _, prev_blockers, prev_threshold_blocked = stage_ready(schedule[prev_stage]) + patience = schedule[prev_stage].get("demotion_patience", 1) + if prev_threshold_blocked and patience > 1: + failures = failure_counts.get(prev_stage, 0) + 1 + if failures < patience: + failure_counts[prev_stage] = failures + stage_index = prev_stage + demotion_cancelled = True + block_stage = prev_stage + blockers_for_log = prev_blockers + else: + failure_counts[prev_stage] = 0 + else: + failure_counts[prev_stage] = 0 + else: + failure_counts.setdefault(prev_stage, 0) + failure_counts[prev_stage] = 0 + + if not demotion_cancelled: + failure_counts.setdefault(stage_index, 0) + if stage_index != prev_stage: + failure_counts[stage_index] = 0 + + hard_weight = schedule[stage_index]["hard_weight"] + medium_weight = schedule[stage_index].get("medium_weight", 0.0) + if not medium_pool: + medium_weight = 0.0 + + if block_stage is None and promotion_block_stage is not None: + block_stage = promotion_block_stage + blockers_for_log = promotion_blockers_for_log if logger and block_stage is not None and blockers_for_log: block_signature = (block_stage, 
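(This patch turns the two-pool sampler into base/medium/hard tiers and adds demotion patience; a sketch of a config using the new keys, with illustrative values: medium_datasets at the top level, plus per-stage medium_weight and demotion_patience. A stage with demotion_patience N is only demoted on the Nth consecutive failed threshold check.)

    curriculum = {
        "enabled": True,
        "base_datasets": ["math_easy"],
        "medium_datasets": ["math_medium"],
        "hard_datasets": ["math_hard"],
        "schedule": [
            {"step": 0, "hard_weight": 0.0, "medium_weight": 0.0},
            {
                "step": 50_000,
                "hard_weight": 0.2,       # P(hard) = 0.2
                "medium_weight": 0.3,     # P(medium) = 0.3, so P(base) = 0.5
                "demotion_patience": 3,
                "success_thresholds": [
                    {"dataset": "math_medium", "threshold": 0.6, "window": 256},
                ],
            },
        ],
    }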
tuple(blockers_for_log)) @@ -540,8 +619,11 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str]]: stage_state["index"] = stage_index - if hard_pool and random.random() < hard_weight: + choice = random.random() + if hard_pool and choice < hard_weight: yield random.choice(hard_pool) + elif medium_pool and choice < hard_weight + medium_weight: + yield random.choice(medium_pool) else: yield random.choice(base_pool) @@ -564,7 +646,11 @@ def __init__( self.data_stream = data_stream self.trainer_state = trainer_state self.stats_stream = stats_stream - self.sliding_aggregator = SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size) + self.sliding_aggregator = None + if is_training: + self.sliding_aggregator = SlidingWindowAggregator( + window_size=cfg.actor.throughput_window_size + ) self.llms = llms self.loop_start_time = -1 self.cfg: DictConfig = cfg @@ -655,13 +741,22 @@ def update_stats(self, rollout_results: List[RolloutResult]): else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") - prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts] - output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts] - self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) - sliding_window_stats = self.sliding_aggregator.get_stats() - if sliding_window_stats is not None: - for k, v in sliding_window_stats.items(): - self.sliding_stats[k].append(v) + if self.sliding_aggregator: + prompt_length_tokens = [ + training_text.prompt_tokens + for result in rollout_results + for training_text in result.training_texts + ] + output_length_tokens = [ + training_text.output_tokens + for result in rollout_results + for training_text in result.training_texts + ] + self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) + sliding_window_stats = self.sliding_aggregator.get_stats() + if sliding_window_stats is not None: + for k, v in sliding_window_stats.items(): + self.sliding_stats[k].append(v) From e7c46d543be916f6aac92457448c3a325ed7a018 Mon Sep 17 00:00:00 2001 From: rafapi Date: Mon, 20 Oct 2025 12:17:10 +0000 Subject: [PATCH 07/12] Improve cooling --- pipelinerl/actor.py | 187 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 174 insertions(+), 13 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index 64fbcf2c..e309bc20 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -194,6 +194,22 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] if not isinstance(raw_schedule, list): raw_schedule = [raw_schedule] + default_cooldown = curriculum_cfg.get("default_promotion_cooldown_samples", 8000) + try: + default_cooldown = int(default_cooldown) + except (TypeError, ValueError): + default_cooldown = 8000 + if default_cooldown < 0: + default_cooldown = 0 + + default_hysteresis = curriculum_cfg.get("default_threshold_hysteresis", 0.02) + try: + default_hysteresis = float(default_hysteresis) + except (TypeError, ValueError): + default_hysteresis = 0.02 + if default_hysteresis < 0.0: + default_hysteresis = 0.0 + parsed_schedule: list[dict] = [] for entry in raw_schedule: step = int(entry.get("step", 0)) @@ -205,6 +221,12 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: except (TypeError, ValueError): medium_weight = 0.0 medium_weight = max(0.0, min(1.0, medium_weight)) + weight_sum = 
medium_weight + hard_weight + max_non_base = 0.85 + if weight_sum > max_non_base and weight_sum > 0: + scale = max_non_base / weight_sum + medium_weight *= scale + hard_weight *= scale demotion_patience_value = entry.get("demotion_patience", 1) try: demotion_patience = int(demotion_patience_value) @@ -212,6 +234,21 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: demotion_patience = 1 if demotion_patience < 1: demotion_patience = 1 + cooldown_value = entry.get( + "promotion_cooldown_samples", entry.get("cooldown_samples", default_cooldown) + ) + try: + promotion_cooldown_samples = int(cooldown_value) + except (TypeError, ValueError): + promotion_cooldown_samples = default_cooldown + if promotion_cooldown_samples < 0: + promotion_cooldown_samples = 0 + hysteresis_value = entry.get("threshold_hysteresis", default_hysteresis) + try: + threshold_hysteresis = float(hysteresis_value) + except (TypeError, ValueError): + threshold_hysteresis = default_hysteresis + threshold_hysteresis = max(0.0, threshold_hysteresis) thresholds_cfg = entry.get("success_thresholds", []) or [] if not isinstance(thresholds_cfg, list): thresholds_cfg = [thresholds_cfg] @@ -241,6 +278,8 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: "medium_weight": medium_weight, "thresholds": thresholds, "demotion_patience": demotion_patience, + "promotion_cooldown_samples": promotion_cooldown_samples, + "threshold_hysteresis": threshold_hysteresis, } ) parsed_schedule.sort(key=lambda item: item["step"]) @@ -475,39 +514,91 @@ def curriculum_iter( for threshold in stage["thresholds"]: success_tracker.ensure_window(threshold["dataset"], threshold["window"]) - def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: + def stage_ready(stage_cfg: dict, relaxation: float = 0.0) -> tuple[bool, list[str], bool, list[dict]]: if not stage_cfg["thresholds"] or success_tracker is None: - return True, [], False + return True, [], False, [] blockers: list[str] = [] threshold_blocked = False + stats: list[dict] = [] for threshold in stage_cfg["thresholds"]: dataset = threshold["dataset"] threshold_value = threshold["threshold"] window = threshold["window"] min_samples = threshold.get("min_samples") + total_samples = success_tracker.total_samples(dataset) if min_samples is not None: - total_samples = success_tracker.total_samples(dataset) if total_samples < min_samples: blockers.append( f"{dataset}: waiting for {min_samples} samples (have {total_samples})" ) + stats.append( + { + "dataset": dataset, + "success_mean": None, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "min_samples", + } + ) continue success_mean_value = success_tracker.success_mean(dataset, window) if success_mean_value is None: blockers.append(f"{dataset}: insufficient window data (need {window})") + stats.append( + { + "dataset": dataset, + "success_mean": None, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "insufficient_window", + } + ) continue - if success_mean_value < threshold_value: + adjusted_threshold = threshold_value - relaxation + if success_mean_value < adjusted_threshold: threshold_blocked = True blockers.append( - f"{dataset}: success_mean {success_mean_value:.3f} < {threshold_value:.3f} (window={window})" + f"{dataset}: success_mean {success_mean_value:.3f} < {adjusted_threshold:.3f} 
(threshold={threshold_value:.3f}, relaxation={relaxation:.3f}, window={window})" ) - return (len(blockers) == 0), blockers, threshold_blocked + stats.append( + { + "dataset": dataset, + "success_mean": success_mean_value, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "threshold", + } + ) + continue + stats.append( + { + "dataset": dataset, + "success_mean": success_mean_value, + "threshold": threshold_value, + "relaxation": relaxation, + "window": window, + "min_samples": min_samples, + "total_samples": total_samples, + "status": "ok", + } + ) + return (len(blockers) == 0), blockers, threshold_blocked, stats current_stage = -1 last_block_log: tuple[int, tuple[str, ...]] | None = None if stage_state is None: stage_state = {"index": 0} stage_state.setdefault("consecutive_failures", {}) + stage_state.setdefault("last_promotion_samples", -math.inf) while True: samples_processed = trainer_state.samples_processed or 0 @@ -529,42 +620,86 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: stage_index = min(current_stage, desired_stage_index) promotion_blockers: list[str] = [] + promotion_stats_for_log: list[dict] = [] # Walk backwards until the current stage is ready (or we reach stage 0) while stage_index > 0: - ready, _, _ = stage_ready(schedule[stage_index]) + ready, _, _, _ = stage_ready(schedule[stage_index], relaxation=0.0) if ready: break stage_index -= 1 - ready, current_blockers, _ = stage_ready(schedule[stage_index]) + ready, current_blockers, _, current_stats = stage_ready( + schedule[stage_index], relaxation=0.0 + ) if not ready and stage_index > 0: # If even after walking back we are not ready, fall back further until 0 while stage_index > 0 and not ready: stage_index -= 1 - ready, current_blockers, _ = stage_ready(schedule[stage_index]) + ready, current_blockers, _, current_stats = stage_ready( + schedule[stage_index], relaxation=0.0 + ) # Attempt to promote by at most one stage towards the desired stage if stage_index < desired_stage_index: next_index = stage_index + 1 - next_ready, blockers, _ = stage_ready(schedule[next_index]) + next_ready, blockers, _, next_stats = stage_ready( + schedule[next_index], relaxation=0.0 + ) if next_ready: stage_index = next_index current_blockers = [] + current_stats = next_stats else: promotion_blockers = blockers + promotion_stats_for_log = next_stats promotion_block_stage: int | None = None promotion_blockers_for_log: list[str] = [] if stage_index < desired_stage_index: promotion_block_stage = stage_index + 1 promotion_blockers_for_log = promotion_blockers or current_blockers + if promotion_blockers_for_log: + if not promotion_stats_for_log: + promotion_stats_for_log = current_stats + else: + promotion_block_stage = None + + candidate_stage_index = stage_index + if candidate_stage_index > prev_stage + 1: + candidate_stage_index = prev_stage + 1 + + cooldown_blockers: list[str] = [] + cooldown_stats: list[dict] = [] + last_promotion_samples = stage_state.get("last_promotion_samples", -math.inf) + if candidate_stage_index > prev_stage: + cooldown_required = schedule[candidate_stage_index].get( + "promotion_cooldown_samples", 0 + ) + samples_since_promotion = samples_processed - last_promotion_samples + if samples_since_promotion < cooldown_required: + cooldown_blockers = [ + ( + f"promotion cooldown active: {samples_since_promotion} / " + f"{cooldown_required} samples since last promotion" + ) + ] + cooldown_stats = 
current_stats + candidate_stage_index = prev_stage + else: + stage_state["last_promotion_samples"] = samples_processed + + stage_index = candidate_stage_index blockers_for_log: list[str] = [] block_stage: int | None = None failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {}) + demotion_stats_for_log: list[dict] = [] demotion_cancelled = False if prev_stage > stage_index: - _, prev_blockers, prev_threshold_blocked = stage_ready(schedule[prev_stage]) + _, prev_blockers, prev_threshold_blocked, prev_stats = stage_ready( + schedule[prev_stage], + relaxation=schedule[prev_stage].get("threshold_hysteresis", 0.0), + ) patience = schedule[prev_stage].get("demotion_patience", 1) if prev_threshold_blocked and patience > 1: failures = failure_counts.get(prev_stage, 0) + 1 @@ -574,6 +709,8 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: demotion_cancelled = True block_stage = prev_stage blockers_for_log = prev_blockers + promotion_stats_for_log = prev_stats + demotion_stats_for_log = prev_stats else: failure_counts[prev_stage] = 0 else: @@ -592,27 +729,51 @@ def stage_ready(stage_cfg: dict) -> tuple[bool, list[str], bool]: if not medium_pool: medium_weight = 0.0 + stats_for_log: list[dict] = [] if block_stage is None and promotion_block_stage is not None: block_stage = promotion_block_stage blockers_for_log = promotion_blockers_for_log + stats_for_log = promotion_stats_for_log + if block_stage is None and cooldown_blockers: + block_stage = prev_stage + 1 if prev_stage + 1 < len(schedule) else prev_stage + blockers_for_log = cooldown_blockers + stats_for_log = cooldown_stats + if not stats_for_log and demotion_stats_for_log: + stats_for_log = demotion_stats_for_log if logger and block_stage is not None and blockers_for_log: block_signature = (block_stage, tuple(blockers_for_log)) if block_signature != last_block_log: + stats_desc = "" + if stats_for_log: + formatted = [] + for stat in stats_for_log: + mean_val = stat.get("success_mean") + mean_str = f"{mean_val:.3f}" if mean_val is not None else "n/a" + formatted.append( + f"{stat.get('dataset')}: mean={mean_str}, thr={stat.get('threshold', 0.0):.3f}, " + f"rel={stat.get('relaxation', 0.0):.3f}, window={stat.get('window')}, " + f"samples={stat.get('total_samples')}, status={stat.get('status', 'n/a')}" + ) + stats_desc = " | stats: " + "; ".join(formatted) logger.info( - "Curriculum stage %d gated by: %s", + "Curriculum stage %d gated by: %s%s", block_stage, "; ".join(blockers_for_log), + stats_desc, ) last_block_log = block_signature elif stage_index >= desired_stage_index: last_block_log = None if logger and stage_index != current_stage: + base_weight = max(0.0, 1.0 - hard_weight - medium_weight) logger.info( - "Curriculum stage %d active (samples_processed=%d, hard_weight=%.3f)", + "Curriculum stage %d active (samples_processed=%d, base=%.3f, medium=%.3f, hard=%.3f)", stage_index, samples_processed, + base_weight, + medium_weight, hard_weight, ) current_stage = stage_index From a9d3c70b1c961b6be34defb77a0a8e9162b6a8b8 Mon Sep 17 00:00:00 2001 From: rafapi Date: Tue, 21 Oct 2025 11:40:17 +0000 Subject: [PATCH 08/12] simplify stage tracking --- pipelinerl/actor.py | 567 ++++++++++++-------------------------------- 1 file changed, 156 insertions(+), 411 deletions(-) diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index e309bc20..88bfc193 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -7,11 +7,11 @@ import queue import random import time -from collections import defaultdict, deque 
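(For the record, the full set of knobs PATCH 07 added above, sketched with illustrative values; PATCH 08, which begins here, removes most of this machinery again. The code above defaults default_promotion_cooldown_samples to 8000 and default_threshold_hysteresis to 0.02, caps medium+hard at 0.85 of the mixture, and accepts cooldown_samples as an alias for the per-stage key.)

    curriculum = {
        "enabled": True,
        "base_datasets": ["math_easy"],
        "hard_datasets": ["math_hard"],
        "default_promotion_cooldown_samples": 8000,  # min samples between promotions
        "default_threshold_hysteresis": 0.02,        # slack subtracted from thresholds on demotion checks
        "schedule": [
            {"step": 0, "hard_weight": 0.0},
            {
                "step": 50_000,
                "hard_weight": 0.3,
                "promotion_cooldown_samples": 16_000,
                "threshold_hysteresis": 0.05,
                "success_thresholds": [
                    {"dataset": "math_easy", "threshold": 0.7, "window": 512},
                ],
            },
        ],
    }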
+from collections import Counter, defaultdict from multiprocessing.managers import SharedMemoryManager from pathlib import Path from queue import Empty -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional import aiohttp import hydra @@ -21,8 +21,6 @@ from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM -from typing import Dict, List, Optional - import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb from pipelinerl.rollouts import BaseMetrics, RolloutResult @@ -132,58 +130,6 @@ def get_stats(self): -class CurriculumSuccessTracker: - def __init__(self) -> None: - self._buffers: dict[str, deque[int]] = {} - self._max_windows: dict[str, int] = {} - self._total_counts: defaultdict[str, int] = defaultdict(int) - - def ensure_window(self, dataset: str, window: int) -> None: - if window <= 0: - window = 1 - current = self._max_windows.get(dataset, 0) - if window <= current: - return - existing = self._buffers.get(dataset, deque(maxlen=window)) - if existing.maxlen != window: - new_buffer = deque(existing, maxlen=window) - else: - new_buffer = existing - self._buffers[dataset] = new_buffer - self._max_windows[dataset] = window - - def update(self, dataset: str, success_values: list[int | bool]) -> None: - if not success_values: - return - buffer = self._buffers.get(dataset) - if buffer is None: - maxlen = self._max_windows.get(dataset, max(1, len(success_values))) - buffer = deque(maxlen=maxlen) - self._buffers[dataset] = buffer - self._max_windows[dataset] = maxlen - for value in success_values: - buffer.append(1 if bool(value) else 0) - self._total_counts[dataset] += 1 - - def success_mean(self, dataset: str, window: Optional[int] = None) -> Optional[float]: - buffer = self._buffers.get(dataset) - if buffer is None or not buffer: - return None - if window is None or window <= 0: - values = list(buffer) - else: - values = list(buffer)[-window:] - if len(values) < window: - return None - if not values: - return None - return sum(values) / len(values) - - def total_samples(self, dataset: str) -> int: - return self._total_counts.get(dataset, 0) - - - def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) @@ -191,28 +137,14 @@ def make_stats_dict() -> dict: def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: raw_schedule = curriculum_cfg.get("schedule", []) if not raw_schedule: - return [{"step": 0, "hard_weight": 0.0, "thresholds": []}] + return [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}] if not isinstance(raw_schedule, list): raw_schedule = [raw_schedule] - default_cooldown = curriculum_cfg.get("default_promotion_cooldown_samples", 8000) - try: - default_cooldown = int(default_cooldown) - except (TypeError, ValueError): - default_cooldown = 8000 - if default_cooldown < 0: - default_cooldown = 0 - - default_hysteresis = curriculum_cfg.get("default_threshold_hysteresis", 0.02) - try: - default_hysteresis = float(default_hysteresis) - except (TypeError, ValueError): - default_hysteresis = 0.02 - if default_hysteresis < 0.0: - default_hysteresis = 0.0 - parsed_schedule: list[dict] = [] for entry in raw_schedule: step = int(entry.get("step", 0)) + if step < 0: + step = 0 hard_weight = float(entry.get("hard_weight", 0.0)) hard_weight = max(0.0, min(1.0, hard_weight)) medium_weight_value = entry.get("medium_weight", 0.0) @@ -222,70 +154,106 @@ def parse_curriculum_schedule(curriculum_cfg) -> list[dict]: medium_weight = 0.0 
medium_weight = max(0.0, min(1.0, medium_weight)) weight_sum = medium_weight + hard_weight - max_non_base = 0.85 - if weight_sum > max_non_base and weight_sum > 0: - scale = max_non_base / weight_sum + if weight_sum > 1.0 and weight_sum > 0.0: + scale = 1.0 / weight_sum medium_weight *= scale hard_weight *= scale - demotion_patience_value = entry.get("demotion_patience", 1) - try: - demotion_patience = int(demotion_patience_value) - except (TypeError, ValueError): - demotion_patience = 1 - if demotion_patience < 1: - demotion_patience = 1 - cooldown_value = entry.get( - "promotion_cooldown_samples", entry.get("cooldown_samples", default_cooldown) - ) - try: - promotion_cooldown_samples = int(cooldown_value) - except (TypeError, ValueError): - promotion_cooldown_samples = default_cooldown - if promotion_cooldown_samples < 0: - promotion_cooldown_samples = 0 - hysteresis_value = entry.get("threshold_hysteresis", default_hysteresis) - try: - threshold_hysteresis = float(hysteresis_value) - except (TypeError, ValueError): - threshold_hysteresis = default_hysteresis - threshold_hysteresis = max(0.0, threshold_hysteresis) - thresholds_cfg = entry.get("success_thresholds", []) or [] - if not isinstance(thresholds_cfg, list): - thresholds_cfg = [thresholds_cfg] - thresholds: list[dict] = [] - for threshold_entry in thresholds_cfg: - dataset = threshold_entry.get("dataset") - if not dataset: + ready_success_cfg = entry.get("ready_success") or [] + if not isinstance(ready_success_cfg, list): + ready_success_cfg = [ready_success_cfg] + ready_success: list[dict] = [] + for cond in ready_success_cfg: + if not isinstance(cond, dict): continue - threshold_value = float(threshold_entry.get("threshold", 1.0)) - window = int(threshold_entry.get("window", threshold_entry.get("window_size", 0) or 1)) - if window <= 0: - window = 1 - min_samples_value = threshold_entry.get("min_samples") - min_samples = int(min_samples_value) if min_samples_value is not None else None - thresholds.append( + dataset = cond.get("dataset") + metric = cond.get("metric", "success_mean") + try: + threshold = float(cond.get("threshold", 1.0)) + except (TypeError, ValueError): + threshold = 1.0 + ready_success.append( { "dataset": dataset, - "threshold": threshold_value, - "window": window, - "min_samples": min_samples, + "metric": metric, + "threshold": threshold, } ) + patience_value = entry.get("ready_patience", 1) + try: + ready_patience = max(1, int(patience_value)) + except (TypeError, ValueError): + ready_patience = 1 parsed_schedule.append( { "step": step, - "hard_weight": hard_weight, "medium_weight": medium_weight, - "thresholds": thresholds, - "demotion_patience": demotion_patience, - "promotion_cooldown_samples": promotion_cooldown_samples, - "threshold_hysteresis": threshold_hysteresis, + "hard_weight": hard_weight, + "ready_success": ready_success, + "ready_patience": ready_patience, } ) parsed_schedule.sort(key=lambda item: item["step"]) return parsed_schedule +def advance_curriculum_stage( + schedule: list[dict], + stage_state: dict, + samples_processed: int, + stats: dict, + logger: logging.Logger | None = None, +) -> None: + if not schedule or stage_state is None: + return + current_idx = int(stage_state.get("index", 0)) + ready_counts = stage_state.setdefault("ready_counts", {}) + advanced = False + + while current_idx + 1 < len(schedule): + next_idx = current_idx + 1 + stage_cfg = schedule[next_idx] + min_step = stage_cfg.get("step", 0) + if samples_processed < min_step: + break + + ready_conditions: list[dict] = 
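(The rewritten parser above reduces each stage to a step, mixture weights, and optional readiness gates consumed by advance_curriculum_stage here. A sketch of the simplified schema, with illustrative values; metric defaults to success_mean and is matched against the actor's published stats keys as {dataset}/{metric}.)

    curriculum = {
        "enabled": True,
        "base_datasets": ["math_easy"],
        "medium_datasets": ["math_medium"],
        "hard_datasets": ["math_hard"],
        "schedule": [
            {"step": 0, "medium_weight": 0.0, "hard_weight": 0.0},
            {
                "step": 50_000,
                "medium_weight": 0.3,
                "hard_weight": 0.2,
                "ready_success": [
                    {"dataset": "math_easy", "metric": "success_mean", "threshold": 0.7},
                ],
                "ready_patience": 2,   # consecutive passing checks required
            },
        ],
    }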
stage_cfg.get("ready_success") or [] + patience = max(1, int(stage_cfg.get("ready_patience", 1))) + + if ready_conditions: + all_pass = True + for cond in ready_conditions: + dataset = cond.get("dataset") + metric = cond.get("metric", "success_mean") + threshold = float(cond.get("threshold", 1.0)) + if dataset: + metric_key = f"{dataset}/{metric}" + else: + metric_key = metric + value = stats.get(metric_key) + if value is None or value < threshold: + all_pass = False + break + if all_pass: + ready_counts[next_idx] = ready_counts.get(next_idx, 0) + 1 + else: + ready_counts[next_idx] = 0 + break + if ready_counts[next_idx] < patience: + break + current_idx = next_idx + stage_state["index"] = current_idx + ready_counts.pop(next_idx, None) + advanced = True + if logger: + logger.info( + "Curriculum stage %d activated (samples_processed=%d)", + current_idx, + samples_processed, + ) + if not advanced: + stage_state["index"] = current_idx + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -329,9 +297,9 @@ async def rollout_and_maybe_produce_result( llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None - logger.info(f"Starting rollout policy for problem {problem['id']}") + logger.debug(f"Starting rollout policy for problem {problem['id']}") rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session) - logger.info(f"Finished rollout policy for problem {problem['id']}") + logger.debug(f"Finished rollout policy for problem {problem['id']}") rollout_result.model_version = model_version token_count += get_number_of_tokens_in_result(rollout_result) # Make a group id that will be different from groups made by another rollout maker @@ -378,7 +346,7 @@ async def rollout_and_maybe_produce_result( if finished_rollouts > old_finished_rollouts: old_finished_rollouts = finished_rollouts save_debug_line({"rollouts_finished": finished_rollouts, "tokens_produced": token_count, "dt": time.time() - start_time, "token_speed": token_count / (time.time() - start_time)}) - logger.info( + logger.debug( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " f"groups in progress: {len(group_rollouts)}, " @@ -454,9 +422,8 @@ def curriculum_iter( trainer_state: TrainerState, curriculum_cfg: DictConfig, logger: logging.Logger | None = None, - success_tracker: CurriculumSuccessTracker | None = None, - stage_state: Optional[dict] = None, parsed_schedule: Optional[list[dict]] = None, + stage_state: Optional[dict] = None, ): curriculum_obj = ( OmegaConf.to_container(curriculum_cfg, resolve=True) @@ -486,304 +453,72 @@ def curriculum_iter( medium_pool = [problem for problem in problems if problem.get("dataset") in medium_names] hard_pool = [problem for problem in problems if problem.get("dataset") in hard_names] - if not hard_pool: - if logger: - logger.warning( - "Curriculum enabled but no problems matched hard_datasets list; falling back to base sampling" - ) - yield from random_iter(problems) - return - - if medium_names and not medium_pool and logger: - logger.warning( - "Curriculum medium_datasets specified but no problems matched; medium weighting will be ignored" - ) + if not base_pool and medium_pool: + base_pool = list(medium_pool) + medium_pool = [] if not base_pool: if logger: - logger.warning("Curriculum enabled but base pool is empty; falling back to medium or hard datasets") - if medium_pool: - base_pool = list(medium_pool) - medium_pool = [] - else: - base_pool = hard_pool + logger.warning("Curriculum 
enabled but no matching datasets were found; falling back to random sampling")
+            yield from random_iter(problems)
+            return
     schedule = parsed_schedule or parse_curriculum_schedule(curriculum_obj)
-    if success_tracker:
-        for stage in schedule:
-            for threshold in stage["thresholds"]:
-                success_tracker.ensure_window(threshold["dataset"], threshold["window"])
-
-    def stage_ready(stage_cfg: dict, relaxation: float = 0.0) -> tuple[bool, list[str], bool, list[dict]]:
-        if not stage_cfg["thresholds"] or success_tracker is None:
-            return True, [], False, []
-        blockers: list[str] = []
-        threshold_blocked = False
-        stats: list[dict] = []
-        for threshold in stage_cfg["thresholds"]:
-            dataset = threshold["dataset"]
-            threshold_value = threshold["threshold"]
-            window = threshold["window"]
-            min_samples = threshold.get("min_samples")
-            total_samples = success_tracker.total_samples(dataset)
-            if min_samples is not None:
-                if total_samples < min_samples:
-                    blockers.append(
-                        f"{dataset}: waiting for {min_samples} samples (have {total_samples})"
-                    )
-                    stats.append(
-                        {
-                            "dataset": dataset,
-                            "success_mean": None,
-                            "threshold": threshold_value,
-                            "relaxation": relaxation,
-                            "window": window,
-                            "min_samples": min_samples,
-                            "total_samples": total_samples,
-                            "status": "min_samples",
-                        }
-                    )
-                    continue
-            success_mean_value = success_tracker.success_mean(dataset, window)
-            if success_mean_value is None:
-                blockers.append(f"{dataset}: insufficient window data (need {window})")
-                stats.append(
-                    {
-                        "dataset": dataset,
-                        "success_mean": None,
-                        "threshold": threshold_value,
-                        "relaxation": relaxation,
-                        "window": window,
-                        "min_samples": min_samples,
-                        "total_samples": total_samples,
-                        "status": "insufficient_window",
-                    }
-                )
-                continue
-            adjusted_threshold = threshold_value - relaxation
-            if success_mean_value < adjusted_threshold:
-                threshold_blocked = True
-                blockers.append(
-                    f"{dataset}: success_mean {success_mean_value:.3f} < {adjusted_threshold:.3f} (threshold={threshold_value:.3f}, relaxation={relaxation:.3f}, window={window})"
-                )
-                stats.append(
-                    {
-                        "dataset": dataset,
-                        "success_mean": success_mean_value,
-                        "threshold": threshold_value,
-                        "relaxation": relaxation,
-                        "window": window,
-                        "min_samples": min_samples,
-                        "total_samples": total_samples,
-                        "status": "threshold",
-                    }
-                )
-                continue
-            stats.append(
-                {
-                    "dataset": dataset,
-                    "success_mean": success_mean_value,
-                    "threshold": threshold_value,
-                    "relaxation": relaxation,
-                    "window": window,
-                    "min_samples": min_samples,
-                    "total_samples": total_samples,
-                    "status": "ok",
-                }
-            )
-        return (len(blockers) == 0), blockers, threshold_blocked, stats
+    if not schedule:
+        schedule = [{"step": 0, "medium_weight": 0.0, "hard_weight": 0.0}]

-    current_stage = -1
-    last_block_log: tuple[int, tuple[str, ...]] | None = None
-    if stage_state is None:
-        stage_state = {"index": 0}
-    stage_state.setdefault("consecutive_failures", {})
-    stage_state.setdefault("last_promotion_samples", -math.inf)
+    current_stage_index = stage_state.get("index", 0) if stage_state is not None else 0
+    last_logged_stage: Optional[int] = None

     while True:
         samples_processed = trainer_state.samples_processed or 0
-        desired_stage_index = 0
-
-        for idx, stage_cfg in enumerate(schedule):
-            step = stage_cfg["step"]
-            if samples_processed >= step:
-                desired_stage_index = idx
-            else:
-                break
-
-        current_stage = int(stage_state.get("index", 0))
-        if current_stage < 0:
-            current_stage = 0
-        if current_stage >= len(schedule):
-            current_stage = len(schedule) - 1
-        prev_stage = current_stage
-
-        stage_index = min(current_stage, desired_stage_index)
-        promotion_blockers: list[str] = []
-        promotion_stats_for_log: list[dict] = []
-
-        # Walk backwards until the current stage is ready (or we reach stage 0)
-        while stage_index > 0:
-            ready, _, _, _ = stage_ready(schedule[stage_index], relaxation=0.0)
-            if ready:
-                break
-            stage_index -= 1
-
-        ready, current_blockers, _, current_stats = stage_ready(
-            schedule[stage_index], relaxation=0.0
-        )
-        if not ready and stage_index > 0:
-            # If even after walking back we are not ready, fall back further until 0
-            while stage_index > 0 and not ready:
-                stage_index -= 1
-                ready, current_blockers, _, current_stats = stage_ready(
-                    schedule[stage_index], relaxation=0.0
-                )
+        max_stage_allowed = current_stage_index
+        while (
+            max_stage_allowed + 1 < len(schedule)
+            and samples_processed >= schedule[max_stage_allowed + 1]["step"]
+        ):
+            max_stage_allowed += 1

-        # Attempt to promote by at most one stage towards the desired stage
-        if stage_index < desired_stage_index:
-            next_index = stage_index + 1
-            next_ready, blockers, _, next_stats = stage_ready(
-                schedule[next_index], relaxation=0.0
-            )
-            if next_ready:
-                stage_index = next_index
-                current_blockers = []
-                current_stats = next_stats
-            else:
-                promotion_blockers = blockers
-                promotion_stats_for_log = next_stats
-        promotion_block_stage: int | None = None
-        promotion_blockers_for_log: list[str] = []
-        if stage_index < desired_stage_index:
-            promotion_block_stage = stage_index + 1
-            promotion_blockers_for_log = promotion_blockers or current_blockers
-            if promotion_blockers_for_log:
-                if not promotion_stats_for_log:
-                    promotion_stats_for_log = current_stats
-            else:
-                promotion_block_stage = None
-
-        candidate_stage_index = stage_index
-        if candidate_stage_index > prev_stage + 1:
-            candidate_stage_index = prev_stage + 1
-
-        cooldown_blockers: list[str] = []
-        cooldown_stats: list[dict] = []
-        last_promotion_samples = stage_state.get("last_promotion_samples", -math.inf)
-        if candidate_stage_index > prev_stage:
-            cooldown_required = schedule[candidate_stage_index].get(
-                "promotion_cooldown_samples", 0
-            )
-            samples_since_promotion = samples_processed - last_promotion_samples
-            if samples_since_promotion < cooldown_required:
-                cooldown_blockers = [
-                    (
-                        f"promotion cooldown active: {samples_since_promotion} / "
-                        f"{cooldown_required} samples since last promotion"
-                    )
-                ]
-                cooldown_stats = current_stats
-                candidate_stage_index = prev_stage
-            else:
-                stage_state["last_promotion_samples"] = samples_processed
-
-        stage_index = candidate_stage_index
-
-        blockers_for_log: list[str] = []
-        block_stage: int | None = None
-        failure_counts: dict[int, int] = stage_state.setdefault("consecutive_failures", {})
-        demotion_stats_for_log: list[dict] = []
-        demotion_cancelled = False
-        if prev_stage > stage_index:
-            _, prev_blockers, prev_threshold_blocked, prev_stats = stage_ready(
-                schedule[prev_stage],
-                relaxation=schedule[prev_stage].get("threshold_hysteresis", 0.0),
-            )
-            patience = schedule[prev_stage].get("demotion_patience", 1)
-            if prev_threshold_blocked and patience > 1:
-                failures = failure_counts.get(prev_stage, 0) + 1
-                if failures < patience:
-                    failure_counts[prev_stage] = failures
-                    stage_index = prev_stage
-                    demotion_cancelled = True
-                    block_stage = prev_stage
-                    blockers_for_log = prev_blockers
-                    promotion_stats_for_log = prev_stats
-                    demotion_stats_for_log = prev_stats
-                else:
-                    failure_counts[prev_stage] = 0
-            else:
-                failure_counts[prev_stage] = 0
+        if stage_state is not None:
+            desired_index = int(stage_state.get("index", 0))
+            current_stage_index = max(0, min(desired_index, max_stage_allowed))
         else:
-            failure_counts.setdefault(prev_stage, 0)
-            failure_counts[prev_stage] = 0
-
-        if not demotion_cancelled:
-            failure_counts.setdefault(stage_index, 0)
-            if stage_index != prev_stage:
-                failure_counts[stage_index] = 0
+            current_stage_index = max_stage_allowed

-        hard_weight = schedule[stage_index]["hard_weight"]
-        medium_weight = schedule[stage_index].get("medium_weight", 0.0)
+        stage_cfg = schedule[current_stage_index]
+        medium_weight = stage_cfg.get("medium_weight", 0.0)
+        hard_weight = stage_cfg.get("hard_weight", 0.0)
         if not medium_pool:
             medium_weight = 0.0
+        if not hard_pool:
+            hard_weight = 0.0
+        base_weight = max(0.0, 1.0 - medium_weight - hard_weight)
+        weight_sum = base_weight + medium_weight + hard_weight
+        if weight_sum == 0.0:
+            base_weight = 1.0
+            medium_weight = 0.0
+            hard_weight = 0.0

-        stats_for_log: list[dict] = []
-        if block_stage is None and promotion_block_stage is not None:
-            block_stage = promotion_block_stage
-            blockers_for_log = promotion_blockers_for_log
-            stats_for_log = promotion_stats_for_log
-        if block_stage is None and cooldown_blockers:
-            block_stage = prev_stage + 1 if prev_stage + 1 < len(schedule) else prev_stage
-            blockers_for_log = cooldown_blockers
-            stats_for_log = cooldown_stats
-        if not stats_for_log and demotion_stats_for_log:
-            stats_for_log = demotion_stats_for_log
-
-        if logger and block_stage is not None and blockers_for_log:
-            block_signature = (block_stage, tuple(blockers_for_log))
-            if block_signature != last_block_log:
-                stats_desc = ""
-                if stats_for_log:
-                    formatted = []
-                    for stat in stats_for_log:
-                        mean_val = stat.get("success_mean")
-                        mean_str = f"{mean_val:.3f}" if mean_val is not None else "n/a"
-                        formatted.append(
-                            f"{stat.get('dataset')}: mean={mean_str}, thr={stat.get('threshold', 0.0):.3f}, "
-                            f"rel={stat.get('relaxation', 0.0):.3f}, window={stat.get('window')}, "
-                            f"samples={stat.get('total_samples')}, status={stat.get('status', 'n/a')}"
-                        )
-                    stats_desc = " | stats: " + "; ".join(formatted)
-                logger.info(
-                    "Curriculum stage %d gated by: %s%s",
-                    block_stage,
-                    "; ".join(blockers_for_log),
-                    stats_desc,
-                )
-                last_block_log = block_signature
-        elif stage_index >= desired_stage_index:
-            last_block_log = None
-
-        if logger and stage_index != current_stage:
-            base_weight = max(0.0, 1.0 - hard_weight - medium_weight)
+        if logger and last_logged_stage != current_stage_index:
             logger.info(
                 "Curriculum stage %d active (samples_processed=%d, base=%.3f, medium=%.3f, hard=%.3f)",
-                stage_index,
+                current_stage_index,
                 samples_processed,
                 base_weight,
                 medium_weight,
                 hard_weight,
             )
-            current_stage = stage_index
+            last_logged_stage = current_stage_index

-        stage_state["index"] = stage_index
+        if stage_state is not None:
+            stage_state["index"] = current_stage_index

         choice = random.random()
-        if hard_pool and choice < hard_weight:
+        hard_cutoff = hard_weight
+        medium_cutoff = hard_cutoff + medium_weight
+        if hard_pool and choice < hard_cutoff:
             yield random.choice(hard_pool)
-        elif medium_pool and choice < hard_weight + medium_weight:
+        elif medium_pool and choice < medium_cutoff:
             yield random.choice(medium_pool)
         else:
             yield random.choice(base_pool)

@@ -818,7 +553,7 @@ def __init__(
         self.is_training = is_training
         self.is_scheduling_paused = False
         self.debug_mode = bool(cfg.debug.mode)
-        self.curriculum_tracker: CurriculumSuccessTracker | None = None
+        self.curriculum_schedule: list[dict] | None = None
         self.curriculum_stage_state: dict | None = None
         self.smm: SharedMemoryManager | None = None

@@ -867,6 +602,7 @@ def init_stats(self):
         self.latency_list = []
         self.model_versions_list = []
         self.sliding_stats = defaultdict(list)
+        self.answer_status_counts = Counter()

     def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
         metrics = {}
@@ -893,15 +629,15 @@ def update_stats(self, rollout_results: List[RolloutResult]):
             for k, v in all_metrics.items():
                 if isinstance(v, list):
                     self.stats[k][dataset_name][group_id] += v
-                    if k == "success" and self.curriculum_tracker:
-                        self.curriculum_tracker.update(dataset_name, v)
                 elif isinstance(v, float) | isinstance(v, bool) | isinstance(v, int):
                     self.stats[k][dataset_name][group_id].append(v)
-                    if k == "success" and self.curriculum_tracker:
-                        self.curriculum_tracker.update(dataset_name, [v])
                 else:
                     raise ValueError(f"Unsupported metric type: {type(v)} for key {k}")

+            status = getattr(result, "answer_status", None)
+            if status in {"correct", "wrong", "unparsable", "no_answer"}:
+                self.answer_status_counts[status] += 1
+
         if self.sliding_aggregator:
             prompt_length_tokens = [
                 training_text.prompt_tokens
@@ -949,27 +685,23 @@ def run(self, dataset: list[tuple[str, dict]]):
                 else curriculum_cfg
            )
            parsed_schedule = parse_curriculum_schedule(curriculum_obj)
-            self.curriculum_tracker = CurriculumSuccessTracker()
-            for stage in parsed_schedule:
-                for threshold in stage["thresholds"]:
-                    self.curriculum_tracker.ensure_window(threshold["dataset"], threshold["window"])
+            self.curriculum_schedule = parsed_schedule
             self.curriculum_stage_state = {"index": 0}
             problem_iter = curriculum_iter(
                 dataset,
                 trainer_state=self.trainer_state,
                 curriculum_cfg=curriculum_cfg,
                 logger=logger,
-                success_tracker=self.curriculum_tracker,
-                stage_state=self.curriculum_stage_state,
                 parsed_schedule=parsed_schedule,
+                stage_state=self.curriculum_stage_state,
             )
         else:
             problem_iter = random_iter(dataset)
-            self.curriculum_tracker = None
+            self.curriculum_schedule = None
             self.curriculum_stage_state = None
     else:
         problem_iter = sequential_iter(dataset)
-        self.curriculum_tracker = None
+        self.curriculum_schedule = None
         self.curriculum_stage_state = None
     assert self.trainer_state.propagated_weight_version is not None
@@ -1136,8 +868,21 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
         stats |= loop_stats
         for k, v in self.sliding_stats.items():
             stats[k] = sum(v) / len(v) if v else 0
-        if self.curriculum_stage_state is not None:
+        if self.curriculum_schedule and self.curriculum_stage_state is not None:
+            advance_curriculum_stage(
+                self.curriculum_schedule,
+                self.curriculum_stage_state,
+                self.trainer_state.samples_processed or 0,
+                stats,
+                logger,
+            )
             stats["curriculum_stage_active"] = self.curriculum_stage_state.get("index", 0)
+
+        total_status = sum(self.answer_status_counts.values())
+        if total_status:
+            for status, count in self.answer_status_counts.items():
+                stats[f"{split_name}answer_status_{status}_count"] = count
+                stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status
         if self.cfg.wandb.use_wandb:
             wandb.log({f"actor/{k}": v for k, v in stats.items()})
         stats_writer.write(stats)

From 3e06b59c10d5c93f55252f18c3f27fbbace1083c Mon Sep 17 00:00:00 2001
From: rafapi
Date: Tue, 21 Oct 2025 11:42:32 +0000
Subject: [PATCH 09/12] Track answer status

---
 pipelinerl/domains/math/rollouts.py | 35 ++++++++++++++++++++++++++---
 1 file changed, 32 insertions(+), 3 deletions(-)

diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 862375d0..1f70c11d 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -1,7 +1,9 @@
+import json
+import logging
+import random
 import re
 import time
-import random
-import logging
+from pathlib import Path

 import aiohttp
 from omegaconf import DictConfig
@@ -14,6 +16,7 @@
 from pipelinerl.async_llm import llm_async_generate, make_training_text
 from .verifier_api import verify_answer_rpc

+logger = logging.getLogger(__name__)

 class Metrics(BaseMetrics):
     penalty: float
@@ -139,6 +142,30 @@ def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) ->
         return ((max_length - buffer_tokens) - sequence_length) / buffer_tokens
     return 0.

+
+def log_answer_status(cfg: DictConfig, problem: dict, answer_status: str, reward: float, latency: float) -> None:
+    """
+    Metric logging for answer status - correct, wrong, no_answer, unparsable
+    """
+    try:
+        log_dir = Path(cfg.output_dir) if cfg.output_dir else None
+        if not log_dir:
+            return
+        log_path = log_dir / "answer_status.jsonl"
+        record = {
+            "t": time.time(),
+            "problem_id": problem.get("id"),
+            "dataset": problem.get("dataset"),
+            "answer_status": answer_status,
+            "reward": reward,
+            "latency": latency,
+        }
+        with log_path.open("a", encoding="utf-8") as handle:
+            handle.write(json.dumps(record))
+            handle.write("\n")
+    except Exception:
+        logger.debug("Failed to append answer status log", exc_info=True)
+
 async def generate_math_rollout(
     cfg: DictConfig,
     llm: TrainableLLM,
@@ -199,6 +226,7 @@ async def generate_math_rollout(
             )
             reward += overlong_penalty
         trace.reward = reward
+        log_answer_status(cfg, problem, answer_status, reward, latency)

         # Prefer backend-provided finish reason if available; normalize for comparisons
         if isinstance(trace.metadata, dict):
@@ -249,6 +277,7 @@ async def generate_math_rollout(
     return RolloutResult(
         training_texts=[trace],
         metrics=metrics,
-        latency=latency,
+        latency=latency,
         dataset_name=problem.get("dataset"),
+        answer_status=answer_status,
     )

From 7a63df9d444ce75bf96a5ae2346748009a407a34 Mon Sep 17 00:00:00 2001
From: rafapi
Date: Tue, 21 Oct 2025 11:48:07 +0000
Subject: [PATCH 10/12] Remove delimiter tags

---
 pipelinerl/domains/math/verifier_api.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/pipelinerl/domains/math/verifier_api.py b/pipelinerl/domains/math/verifier_api.py
index 9fa9b4ef..59f66dfc 100644
--- a/pipelinerl/domains/math/verifier_api.py
+++ b/pipelinerl/domains/math/verifier_api.py
@@ -1,6 +1,7 @@
 import time
 import requests
 import asyncio
+import re
 from concurrent.futures import ProcessPoolExecutor
 import aiohttp
 import uvicorn
@@ -61,6 +62,18 @@ def timeout_handler(signum, frame):
         signal.signal(signal.SIGALRM, original_handler)


+DELIMITER_STR = re.compile(r"\[END FINAL RESPONSE\]", flags=re.IGNORECASE)
+
+
+def strip_delimiter_strings(text: str) -> str:
+    if not text:
+        return text
+    stripped = DELIMITER_STR.sub("", text)
+    # Remove lines that became empty after sentinel stripping to avoid parsing noise
+    cleaned_lines = [line for line in stripped.splitlines() if line.strip()]
+    return "\n".join(cleaned_lines)
+
+
 def verify_answer(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str:
     """
     Checks if a predicted answer matches a gold (correct) answer by making a request to the math_verify package.
@@ -88,13 +101,13 @@ def verify_answer(prediction: str, gold: str, strict: bool = True, max_predictio
 def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_length: int = 1000) -> str:
-    import re
-
     try:
         # Input Sanitization / Validation
         if not isinstance(prediction, str) or not isinstance(gold, str):
             raise ValueError("Prediction and gold must be strings")

+        prediction = strip_delimiter_strings(prediction)
+
         # Try extracting from \boxed{...} first
         boxed_start = prediction.rfind("\\boxed{")
@@ -109,7 +122,7 @@ def verify_math(prediction: str, gold: str, strict: bool = True, max_prediction_
         # Fallback: look for <answer>...</answer> tags
         answer_match = re.findall(r"<answer>(.*?)</answer>", prediction, re.DOTALL)
         if answer_match:
-            extracted_prediction = answer_match[-1].strip()  # last one if multiple
+            extracted_prediction = strip_delimiter_strings(answer_match[-1].strip())  # last one
         else:
             raise NoAnswerException()
@@ -225,5 +238,3 @@ async def health():
         return JSONResponse(content={"status": "ok"})

     uvicorn.run(app, host="0.0.0.0", port=port, timeout_keep_alive=60)
-
-

From fc1ff86dfcc0cb24a96b64734975ba2557850e77 Mon Sep 17 00:00:00 2001
From: rafapi
Date: Tue, 21 Oct 2025 11:48:47 +0000
Subject: [PATCH 11/12] Track ans status in rollout

---
 pipelinerl/rollouts.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pipelinerl/rollouts.py b/pipelinerl/rollouts.py
index dcb27f2d..6ee2b6c2 100644
--- a/pipelinerl/rollouts.py
+++ b/pipelinerl/rollouts.py
@@ -64,3 +64,4 @@ class RolloutResult(BaseModel):
     model_version: int | None = None
     dataset_name: str | None = None
     group_id: str | None = None
+    answer_status: str | None = None

From e6c9463cedb26b110c639e58858f5a3f735970fb Mon Sep 17 00:00:00 2001
From: rafapi
Date: Mon, 27 Oct 2025 16:42:22 +0000
Subject: [PATCH 12/12] rollout counter

---
 pipelinerl/actor.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py
index 88bfc193..5e7eeb18 100644
--- a/pipelinerl/actor.py
+++ b/pipelinerl/actor.py
@@ -603,6 +603,7 @@ def init_stats(self):
         self.model_versions_list = []
         self.sliding_stats = defaultdict(list)
         self.answer_status_counts = Counter()
+        self.dataset_sample_counts = Counter()

     def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]:
         metrics = {}
@@ -637,6 +638,8 @@ def update_stats(self, rollout_results: List[RolloutResult]):
             status = getattr(result, "answer_status", None)
             if status in {"correct", "wrong", "unparsable", "no_answer"}:
                 self.answer_status_counts[status] += 1
+            if dataset_name:
+                self.dataset_sample_counts[str(dataset_name)] += 1

         if self.sliding_aggregator:
             prompt_length_tokens = [
@@ -883,6 +886,12 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict):
             for status, count in self.answer_status_counts.items():
                 stats[f"{split_name}answer_status_{status}_count"] = count
                 stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status
+        total_rollouts = sum(self.dataset_sample_counts.values())
+        if total_rollouts:
+            stats["dataset_rollouts_total"] = total_rollouts
+            for dataset, count in self.dataset_sample_counts.items():
+                stats[f"{dataset}/rollout_count"] = count
+                stats[f"{dataset}/rollout_ratio"] = count / total_rollouts
         if self.cfg.wandb.use_wandb:
             wandb.log({f"actor/{k}": v for k, v in stats.items()})
         stats_writer.write(stats)
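
For a quick mental model of the simplified scheduler: after this series, curriculum_iter only moves forward through the schedule as samples_processed crosses each stage's step, and it draws from the hard/medium/base pools with cumulative cutoffs. A runnable sketch under illustrative values follows; pick_stage, sample_problem, and every schedule/pool entry below are hypothetical, and only the field names step, medium_weight, and hard_weight come from the patches:

import random

schedule = [
    {"step": 0, "medium_weight": 0.0, "hard_weight": 0.0},
    {"step": 10_000, "medium_weight": 0.2, "hard_weight": 0.1},
    {"step": 50_000, "medium_weight": 0.3, "hard_weight": 0.3},
]

def pick_stage(samples_processed: int, current: int = 0) -> int:
    # Advance only while the next stage's step count has been reached.
    allowed = current
    while allowed + 1 < len(schedule) and samples_processed >= schedule[allowed + 1]["step"]:
        allowed += 1
    return allowed

def sample_problem(stage: dict, base_pool: list, medium_pool: list, hard_pool: list):
    # Same cumulative-cutoff draw as the rewritten curriculum_iter.
    medium_weight = stage.get("medium_weight", 0.0) if medium_pool else 0.0
    hard_weight = stage.get("hard_weight", 0.0) if hard_pool else 0.0
    choice = random.random()
    if hard_pool and choice < hard_weight:
        return random.choice(hard_pool)
    if medium_pool and choice < hard_weight + medium_weight:
        return random.choice(medium_pool)
    return random.choice(base_pool)

stage = schedule[pick_stage(samples_processed=12_345)]  # selects stage 1 here
print(sample_problem(stage, ["easy-1", "easy-2"], ["med-1"], ["hard-1"]))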
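The answer_status.jsonl file that log_answer_status appends under cfg.output_dir is plain JSON Lines, one record per rollout. A hypothetical offline reader; summarize and the example path are not part of the patches, but the record fields (t, problem_id, dataset, answer_status, reward, latency) are:

import json
from collections import Counter
from pathlib import Path

def summarize(log_path: Path) -> Counter:
    # Tally rollouts by answer_status across the whole log
    counts: Counter = Counter()
    with log_path.open(encoding="utf-8") as handle:
        for line in handle:
            counts[json.loads(line)["answer_status"]] += 1
    return counts

# print(summarize(Path("outputs/my-run/answer_status.jsonl")))
# e.g. Counter({'correct': 812, 'wrong': 150, 'no_answer': 30, 'unparsable': 8})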
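The sentinel stripping in PATCH 10 is case-insensitive and also drops lines left empty by the removal. A self-contained check of that behavior, with the function body copied from the patch and the input string invented for illustration:

import re

DELIMITER_STR = re.compile(r"\[END FINAL RESPONSE\]", flags=re.IGNORECASE)

def strip_delimiter_strings(text: str) -> str:
    if not text:
        return text
    stripped = DELIMITER_STR.sub("", text)
    # Drop lines that became empty after the sentinel was removed
    cleaned_lines = [line for line in stripped.splitlines() if line.strip()]
    return "\n".join(cleaned_lines)

raw = "The answer is \\boxed{42}\n[end final response]\n"
assert strip_delimiter_strings(raw) == "The answer is \\boxed{42}"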
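Finally, the counters from PATCH 12 turn into per-dataset ratios in publish_stats. A sketch of the resulting stat keys with made-up counts; split_name and the dataset names are illustrative, while the key formats mirror the patch:

from collections import Counter

answer_status_counts = Counter(correct=8, wrong=3, no_answer=1)
dataset_sample_counts = Counter({"math500": 9, "aime": 3})
split_name = "train/"  # illustrative prefix

stats: dict[str, float] = {}
total_status = sum(answer_status_counts.values())
for status, count in answer_status_counts.items():
    stats[f"{split_name}answer_status_{status}_count"] = count
    stats[f"{split_name}answer_status_{status}_ratio"] = count / total_status
total_rollouts = sum(dataset_sample_counts.values())
stats["dataset_rollouts_total"] = total_rollouts
for dataset, count in dataset_sample_counts.items():
    stats[f"{dataset}/rollout_count"] = count
    stats[f"{dataset}/rollout_ratio"] = count / total_rollouts
print(stats)  # e.g. {'train/answer_status_correct_count': 8, ...}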