diff --git a/.gitignore b/.gitignore index 476aab77..1469bc67 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ celerybeat.pid # SageMath parsed files *.sage.py +node_modules/ # Environments .env @@ -185,4 +186,4 @@ results results/ data/ cache/ -dump.rdb \ No newline at end of file +dump.rdb diff --git a/conf/base.yaml b/conf/base.yaml index 3d426f4c..5bf30c59 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -5,6 +5,7 @@ defaults: - _self_ seed: 42 +use_ray: false finetune: seed: ${..seed} @@ -23,9 +24,9 @@ preprocess: input: actor output: training_data n_workers: 8 - chunk_n_groups: 2 + chunk_n_groups: 8 # queue for loaded raw groups - raw_queue_size: 8 + raw_queue_size: 128 # queue for processed chunks of multiple groups input_queue_size: 32 # queue for ready chunks for multiple groups @@ -47,7 +48,7 @@ llm: temperature: 1.0 test_llm: parameters: - max_tokens: 16000 + max_tokens: 8192 temperature: 1.0 top_p: 0.95 top_k: 50 @@ -67,6 +68,7 @@ vllm_config: tensor-parallel-size: 1 pipeline-parallel-size: 1 generation-config: vllm + max_model_len: 16000 world: replicas: 1 @@ -75,10 +77,13 @@ world: preprocessor_fraction: 0 finetune_fraction: 4 - env_replicas: 2 + # Number of environment servers per actor VLLM server + env_replicas_per_actor: 1 actor_group_port: 9000 environment_start_port: 7777 +# Remote vs embedded environment execution strategy + environment_mode: remote # this will be autocreated based on the config jobs: [] diff --git a/conf/finetune/base.yaml b/conf/finetune/base.yaml index 237e6d56..6fb09310 100644 --- a/conf/finetune/base.yaml +++ b/conf/finetune/base.yaml @@ -36,7 +36,7 @@ learning_rate: 1e-6 # How much to clip the gradient (no clipping if null) gradient_clipping_threshold: 0.3 # Learning rate scheduler type (indexed by completed_steps). -lr_scheduler_type: cosine # could be cosine, constant_with_warmup +lr_scheduler_type: constant # could be cosine, constant_with_warmup # Number of warmup (completed) steps in the learning rate schedule. num_warmup_steps: 50 # Number of gradient accumulation steps. diff --git a/conf/mcp.yaml b/conf/mcp.yaml new file mode 100644 index 00000000..5f8680b1 --- /dev/null +++ b/conf/mcp.yaml @@ -0,0 +1,155 @@ +defaults: + - base + - override finetune: grpo + - _self_ + +use_ray: true + +llm: + use_cache: false + parameters: + max_tokens: 8192 + +test_llm: + parameters: + max_tokens: 8192 + +rewards: + correct_answer_not_finished: 0.0 + buffer_tokens: 2000 + +actor: + rollout_policy: pipelinerl.domains.mcp.generate_mcp_rollout_with_local_env + system_prompt: Please reason step by step, and put your final answer within \boxed{{}}. 
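
(Aside, not part of the patch: a minimal sketch of how the new `use_ray` flag from conf/base.yaml and the dotted `actor.rollout_policy` path above are consumed by `run_actor_loop` in `pipelinerl/actor.py` further down in this diff, assuming `pipelinerl` is importable.)

```python
import hydra.utils
from omegaconf import OmegaConf

# Config keys as introduced in conf/base.yaml and conf/mcp.yaml above.
cfg = OmegaConf.create(
    "use_ray: true\n"
    "actor:\n"
    "  rollout_policy: pipelinerl.domains.mcp.generate_mcp_rollout_with_local_env\n"
)

# run_actor_loop switches the loop implementation on the flag:
#     actor_loop_class = ActorLoopRay if cfg.use_ray else ActorLoop
# and resolves the rollout policy from its dotted path:
rollout_policy = hydra.utils.get_method(cfg.actor.rollout_policy)
```
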
+ rollout_workers: 64 + llm_max_rollouts: 256 + problem_queue_size: 256 + task_template: |- + {task} + shared_memory_entry_size: 200000000 + +preprocess: + shared_memory_entry_size: 2000000000 + +finetune: + seq_length: 32000 + seq_parallel: 8 + +dataset_loader: pipelinerl.domains.math.load_datasets +train_dataset_names: +- open_reasoner_zero_57k +- open_reasoner_zero_extended_72k +test_dataset_names: + - aime_2025 + +vllm_config: + use_v1: true + vllm_kwargs: + enable-auto-tool-choice: "" + tool-call-parser: rl_tool + tool-parser-plugin: ${hydra:runtime.cwd}/pipelinerl/rl_tool_parser_plugin.py + max-num-seqs: 256 + max-num-batched-tokens: 32000 + max_model_len: 32000 + gpu-memory-utilization: 0.9 + +environment: + _target_: tapeagents.mcp.MCPEnvironment + config_path: ${hydra:runtime.cwd}/conf/mcp/python.json + tools_whitelist: + - run_python_code + read_timeout_seconds: 600 + use_cache: false + + +world: + env_replicas_per_actor: 8 + environment_mode: embedded + +agent_max_loops: 3 +agent: + _target_: tapeagents.agent.Agent + name : mcp_agent + max_iterations: 3 + store_llm_calls: true + templates: + system_prompt: | + You are a math-focused AI Agent. Solve problems by combining clear symbolic reasoning + with short, deterministic Python code. + Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration. + Always present the final answer in LaTeX \boxed{{}}. + Do not express emotions or opinions about user questions. + + Workflow: + 1. Draft a brief plan in plain text. + 2. Execute one run_python_code call to compute or verify the result. + 3. Finalize by calling MathAnswer with the LaTeX-formatted answer. + + Python execution policy (run_python_code): + - Use Python strictly for pure computation to verify and validate the final answer. + - No network, file system, OS or environment access. + - Keep snippets minimal and self-contained; avoid large outputs and long-running loops; print only the final result. + + Validation: + - Cross-check results (alternative derivation, invariants, higher precision) before finalizing. + - If execution fails, propose the minimal fix and retry. + Keep replies direct and avoid unnecessary text. + allowed_tools: | + You can call the following tools: + {tools_description} + - run_python_code: deterministic math code; print only the final value. + - MathAnswer: return the LaTeX \boxed{{}} answer when the solution is verified. + Always verify with run_python_code before invoking MathAnswer. + thought_format: | + Important! Respond with the plain text, do not include any JSON or code. + Do not output anything besides what I asked in this message. + allowed_steps: | + Workflow summary: + - Plan briefly in plain text. + - Call run_python_code exactly once per loop to compute/verify. + - Finish with a single MathAnswer tool call carrying the \boxed{{}} result. + format: | + For finalization, reply with a single short sentence that ends in the \boxed{{}} answer, + immediately followed by the MathAnswer function call containing the same \boxed{{}} value. + Never emit unrelated JSON wrappers or duplicate the final thought. + + + nodes: + - _target_: tapeagents.nodes.StandardNode + name: plan + system_prompt: ${agent.templates.system_prompt} + guidance: | + Produce a concise math plan (formulas/checks). You will ALWAYS verify by executing Python code. 
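
(Aside: the plan/code/finalize contract above asks for one short sentence ending in `\boxed{}` plus exactly one `MathAnswer` call. In the OpenAI-style chat format served by vLLM, that finalize turn looks roughly like the sketch below; the `{"answer": string}` shape matches the MathAnswer ToolSpec declared in `env_server.py` later in this diff, and the concrete values are illustrative only.)

```python
# Sketch of a finalize turn; values are made up for illustration.
finalize_message = {
    "role": "assistant",
    "content": "The sum of the series is \\boxed{42}.",
    "tool_calls": [
        {
            "type": "function",
            "function": {
                "name": "MathAnswer",
                # Tool arguments travel as a JSON-encoded string.
                "arguments": "{\"answer\": \"\\\\boxed{42}\"}",
            },
        }
    ],
}
```
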
+ ${agent.templates.thought_format} + steps_prompt: ${agent.templates.allowed_tools} + trim_obs_except_last_n: 2 + + - _target_: tapeagents.nodes.StandardNode + name: code + system_prompt: ${agent.templates.system_prompt} + guidance: | + ALWAYS call run_python_code once to compute/verify the result. + Use exact, deterministic code; print only the final scalar or tuple. + If code fails, fix minimally and call run_python_code again after reviewing the error. + use_known_actions: true + use_function_calls: true + trim_obs_except_last_n: 2 + + - _target_: tapeagents.nodes.StandardNode + name: finalize + system_prompt: ${agent.templates.system_prompt} + guidance: | + Read the last Python stdout value. First, state the answer in one short sentence that ends with LaTeX \boxed{{}}. + Immediately after that sentence, call the MathAnswer tool exactly once with: + name: MathAnswer + arguments: {"answer": ""} + Do not add any extra text around the tool call. Once the sentence is emitted, return only the MathAnswer function call. + steps: + - pipelinerl.domains.mcp.steps.MathAnswer + use_known_actions: true + use_function_calls: true + trim_obs_except_last_n: 2 + next_node: code + +model_path: Qwen/Qwen3-8B +# model_path: /mnt/llmd/base_models/ServiceNow-AI/7_9_25_14b_text_reasoning_sft diff --git a/conf/mcp/python.json b/conf/mcp/python.json new file mode 100644 index 00000000..fcbb4dcf --- /dev/null +++ b/conf/mcp/python.json @@ -0,0 +1,11 @@ +{ + "mcpServers": { + "python_exec": { + "command": "bash", + "args": [ + "-c", + "deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python stdio" + ] + } + } +} \ No newline at end of file diff --git a/conf/miniwob.yaml b/conf/miniwob.yaml index a5bf8bc2..1454e774 100644 --- a/conf/miniwob.yaml +++ b/conf/miniwob.yaml @@ -1,34 +1,32 @@ defaults: - base + - override streams: redis + - override finetune: ppo + - _self_ world: - actor_fraction: 4 - preprocessor_fraction: 1 - finetune_fraction: 3 + actor_fraction: 2 + preprocessor_fraction: 0 + finetune_fraction: 6 # debug: # mode: actor save_tapes: False -output_dir: results/miniwob_debug/${now:%Y-%m-%d}/${now:%H-%M-%S} +output_dir: results/miniwob/${now:%Y-%m-%d}/${now:%H-%M-%S} model_path: meta-llama/Llama-3.1-8B-Instruct finetune: - save_checkpoint_steps: 10 - seq_length: 4096 + seq_length: 16384 # input + output tokens + max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples train_batch_size: 1 gradient_accumulation_passes: 1024 - learning_rate: 1e-6 - optim: adamw_torch - rl: - kl_coef: 0.01 # GRPO beta coefficient - reward_minus_kl_coef: 0.0 # RLOO beta coefficient - use_advantages: true - algo: grpo + +eval_every_n_versions: 10240 # 1024 effective bs * 10 "optim steps" llm: parameters: - max_tokens: 3072 + max_tokens: 4096 # output tokens temperature: 1.0 test_llm: parameters: @@ -39,24 +37,37 @@ test_llm: vllm_config: vllm_kwargs: - enable-auto-tool-choice: "" - tool-call-parser: llama3_json # use hermes for qwen - chat_template: pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja # copy pasted from https://github.com/vllm-project/vllm/blob/main/examples/tool_chat_template_llama3.1_json.jinja - enforce-eager: "" # speed the actor llm startup a bit + max_model_len: 16384 # input + output tokens actor: - rollout_policy: pipelinerl.miniwob.rollouts.generate_miniwob_rollout + rollout_policy: pipelinerl.domains.miniwob.rollouts.generate_miniwob_rollout shared_memory_entry_size: 100000000 + llm_max_rollouts: 32 preprocess: - shared_memory_entry_size: 
1000000000
+  n_workers: 32 # Increase from 8
+  chunk_n_groups: 8 # Increase from 2 for better throughput
+  # queue for loaded raw groups
+  raw_queue_size: 32 # Increase from 8
+  # queue for processed chunks of multiple groups
+  input_queue_size: 64 # Increase from 32
+  # queue for ready chunks for multiple groups
+  output_queue_size: 64 # Increase from 32
+  # ring buffer to replace old samples with new ones when training is slow
+  ring_buffer_size: 1024 # Increase from 128
+  # "virtual" sample queue per lead trainer
+  max_ready_samples_per_lead: 256 # Increase from 64
+  shared_memory_entry_size: 1000000000 # Increase from 100M
 
 # AGENT CONFIGURATION
 agent_max_loops: 10 # max number of agent - environment interactions for each task
+agent_attempts: 3 # number of attempts to run the agent (retry on errors)
+rollout_timeout: 600 # overall timeout for entire rollout in seconds (10 minutes)
+reward_computation: nico
 agent:
   _target_: tapeagents.agent.Agent
   name : web_agent
-  max_iterations: 4 # max number of iterations (make_prompt + llm? + generate_steps) for each loop
+  max_iterations: 4 # max number of iterations (make_prompt + llm + generate_steps) for each loop
   store_llm_calls: true
   templates:
     system_prompt: |
@@ -65,50 +76,64 @@ agent:
       Keep your replies concise and direct. Prioritize clarity and avoid over-elaboration.
       You will be provided with the content of the current page and a task from the user.
       Do not express your emotions or opinions about the user question.
-    allowed_tools: |
-      You have access to the following tools:
-      {tools_description}
-    thought_format: |
-      Important! Respond with the plain text, do not include any JSON or code.
-      Do not output anything besides what I asked in this message.
+    allowed_steps: |
+      You are allowed to produce ONLY steps with the following json schemas:
+      {allowed_steps}
+      Do not reproduce schema when producing the steps, use it as a reference.
+    json_format: |
+      Important! Respond with very simple parsable JSON!
+      Do not use any special characters or code. Do not use new lines, tabs, or any other formatting inside the JSON.
+      Do not output anything besides one simple JSON object.
   nodes:
     - _target_: examples.rl_webagent.agent.WebNode
       name: set_goal
      system_prompt: ${agent.templates.system_prompt}
       guidance: |
-        Produce the thought that describes the intended solution to the task. In the reasoning lines:
+        Produce the reasoning_thought step that describes the intended solution to the task. In the reasoning lines:
         - review the instructions from the user and the content of the page.
         - outline the main task to be accomplished and the steps to be taken to achieve it.
-        - produce definiton of done, that will be checked later to verify if the task was completed.
-        ${agent.templates.thought_format}
-      steps_prompt: ${agent.templates.allowed_tools}
+        - produce definition of done, that will be checked later to verify if the task was completed.
+        Produce only one reasoning_thought step!
+        ${agent.templates.json_format}
+      steps_prompt: ${agent.templates.allowed_steps}
+      steps:
+        - tapeagents.steps.ReasoningThought
       trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages
       max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps
     - _target_: examples.rl_webagent.agent.WebNode
       name: reflect
       system_prompt: ${agent.templates.system_prompt}
       guidance: |
-        Review the current state of the page and previous steps to find the best possible next action to accomplish the task.
- Produce the reflection_thought to describe the current page state, reflect on your last action, describe what is left to do, and what will be the immediate next action. - Produce only one reflection_thought step! - ${agent.templates.thought_format} - steps_prompt: ${agent.templates.allowed_tools} + Produce the reasoning_thought step that describes the current state of the page, the previous actions, and what should be the next best action to accomplish the task. In the reasoning lines: + - think about which information could be relevant to the given task, note relevant BIDs and coordinates. + - describe the last action taken, what were its expected effects on the page, versus the actual effects you can observe. Are they the same or not? if not, what could have gone wrong? + - check if you are stuck with repeating the same action over and over again, if so, try something else and change the action. + - check if you think the task is done, if not give a detailed list of actions to do next to accomplish the task. + - finally, if the task is not done, describe the immediate next action to be performed and its expected effect on the page. + Produce only one reasoning_thought step! Be brief and to the point. You can skip some details if they are not relevant for this step. + ${agent.templates.json_format} + steps_prompt: ${agent.templates.allowed_steps} + steps: + - tapeagents.steps.ReasoningThought trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages max_chars_page_observation: 3000 # keep up to 3000 chars in PageObservation steps - _target_: examples.rl_webagent.agent.WebNode name: act system_prompt: ${agent.templates.system_prompt} guidance: | - Produce the single next tool call to be performed with the current page. - If you think that the task is solved, call the FinalAnswer. + Produce the next action to be performed with the current page. + If you think that the task is solved, produce the final_answer_action. You can interact with the page elements using their BIDs or coordinates as arguments for actions. HINTS: - You can use the BIDs of the elements or the mouse position in x, y coordinates to interact with them. - - To select value in a dropdown or combobox, ALWAYS use SelectOption tool. + - To select value in a dropdown or combobox, ALWAYS use select_action. - To click on a checkbox or radio button, ALWAYS use BID (or coordinates) of the corresponding Text and not the BID (or coordinates) of the element itself. - Press enter key to submit the search query. + - Always produce only one step at a time. + - Step kind is always lowercase and underscore separated. + ${agent.templates.json_format} + steps_prompt: ${agent.templates.allowed_steps} use_known_actions: true - use_function_calls: true steps: - examples.rl_webagent.steps.FinalAnswerAction trim_obs_except_last_n: 3 # keep the last 3 observations from the tape in prompt messages @@ -119,18 +144,18 @@ agent: # ENVIRONMENT CONFIGURATION start_attempts: 3 # number of attempts to start each task environment: - _target_: pipelinerl.miniwob.environment_server.WebEnvironmentServer - miniwob_url: file:///home/toolkit/miniwob-plusplus/miniwob/html/miniwob/ - n_envs: 64 + _target_: pipelinerl.domains.miniwob.environment_server.WebEnvironmentServer + miniwob_url: ??? + n_envs: 32 host: "0.0.0.0" - max_session_inactivity_secs: 300 + env_call_timeout: 60 # timeout for each environment call (e.g. start_task, act, etc.) 
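
(Aside: the `json_format` template above replaces function calling with plain-JSON steps for the set_goal and reflect nodes. A `reasoning_thought` step would be parsed from output shaped roughly like this sketch; the `reasoning` field name is an assumption based on `tapeagents.steps.ReasoningThought`.)

```python
import json

# One flat JSON object per step; "kind" is the lowercase,
# underscore-separated step name the act node's hints describe.
step = {
    "kind": "reasoning_thought",
    "reasoning": "Search box has BID 12; last click worked as expected; "
                 "next: type the query and press Enter.",
}
print(json.dumps(step))
```
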
web_env_target: examples.rl_webagent.environment.WebEnvironment - exp_path: ${output_dir}/env_server + exp_path: null headless: true observation_format: html # DATASET CONFIGURATION -dataset_loader: pipelinerl.miniwob.load_tasks.load_tasks +dataset_loader: pipelinerl.domains.miniwob.load_tasks.load_tasks dataset_loader_params: train_split: 0.6 # 0.6 of tasks for training, 0.4 for testing seeds: [0, 42, 1337, 900, 103] diff --git a/conf/miniwob_grpo.yaml b/conf/miniwob_grpo.yaml new file mode 100644 index 00000000..f6cfeed3 --- /dev/null +++ b/conf/miniwob_grpo.yaml @@ -0,0 +1,10 @@ +defaults: + - miniwob + - override finetune: grpo + - _self_ + +finetune: + seq_length: 16384 # input + output tokens + max_train_steps: 1000 # 1000 optim steps = 1000 * bs samples + train_batch_size: 1 + gradient_accumulation_passes: 1024 diff --git a/conf/miniwob_massimo_grpo.yaml b/conf/miniwob_massimo_grpo.yaml new file mode 100644 index 00000000..b61dcf32 --- /dev/null +++ b/conf/miniwob_massimo_grpo.yaml @@ -0,0 +1,15 @@ +defaults: + - miniwob_grpo + - _self_ + +train_dataset_names: + - massimo_train +test_dataset_names: + - massimo_test + +reward_computation: massimo + +finetune: + gradient_accumulation_passes: 512 + +eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps" diff --git a/conf/miniwob_massimo_ppo.yaml b/conf/miniwob_massimo_ppo.yaml new file mode 100644 index 00000000..53703d56 --- /dev/null +++ b/conf/miniwob_massimo_ppo.yaml @@ -0,0 +1,15 @@ +defaults: + - miniwob + - _self_ + +train_dataset_names: + - massimo_train +test_dataset_names: + - massimo_test + +reward_computation: massimo + +finetune: + gradient_accumulation_passes: 512 + +eval_every_n_versions: 5120 # 512 effective bs * 10 "optim steps" diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index a85f156e..358b3797 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -1,27 +1,30 @@ import asyncio +import json import logging import math import multiprocessing as mp import os import queue -from queue import Empty import random import time from collections import defaultdict from multiprocessing.managers import SharedMemoryManager from pathlib import Path +from queue import Empty +from typing import Callable, Dict, List import aiohttp import hydra +import numpy as np +import ray import uvloop -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pydantic import BaseModel, Field from tapeagents.llms import TrainableLLM -from typing import Dict, List import wandb from pipelinerl.finetune.logging_ import flatten_dict_config, init_wandb -from pipelinerl.rollouts import RolloutResult, BaseMetrics +from pipelinerl.rollouts import BaseMetrics, RolloutResult from pipelinerl.shared_memory_array import SharedMemoryQueue from pipelinerl.state import TrainerState from pipelinerl.streams import ( @@ -42,6 +45,11 @@ logger = logging.getLogger(__name__) +def save_debug_line(data:dict): + data["ts"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + fname = os.environ.get("DEBUG_FILE", "timing_debug.jsonl") + with open(fname, "a") as f: + f.write(json.dumps(data, ensure_ascii=False) + "\n") class SlidingWindowData(BaseModel): prompt_tokens_window: list[list[int]] = Field( @@ -56,8 +64,9 @@ class SlidingWindowData(BaseModel): class SlidingWindowAggregator: - def __init__(self, window_size: int): + def __init__(self, window_size: int, min_samples: int = 5): self.window_size = window_size + self.min_samples = min_samples self.data = SlidingWindowData() def update(self, prompt_tokens: 
list[int], output_tokens: list[int]): @@ -70,8 +79,11 @@ def update(self, prompt_tokens: list[int], output_tokens: list[int]): self.data.timestamps.pop(0) def get_stats(self): - if len(self.data.prompt_tokens_window) < self.window_size: + if len(self.data.prompt_tokens_window) < self.min_samples: + logger.warning("Not enough data to compute sliding stats") return None + elif len(self.data.prompt_tokens_window) < self.window_size: + logger.warning(f"Compute sliding stats over just {len(self.data.prompt_tokens_window)} samples") # 1. How many samples do we produce per second? # 2. How many output tokens do we produce per second? @@ -107,6 +119,10 @@ def make_stats_dict() -> dict: return defaultdict(lambda: defaultdict(list)) +def get_number_of_tokens_in_result(result: RolloutResult) -> int: + return sum(training_text.prompt_tokens + training_text.output_tokens for training_text in result.training_texts) + + async def schedule_rollouts( cfg: DictConfig, attempts: int, @@ -132,10 +148,11 @@ async def schedule_rollouts( active_rollouts = [0] * len(llms) started_rollouts = 0 finished_rollouts = 0 + token_count = 0 # Track rollouts per problem group group_rollouts = {} rollout_policy = hydra.utils.get_method(cfg.actor.rollout_policy) - logger.info(f"Use rollout policy: {rollout_policy}") + logger.info(f"Use rollout policy: {rollout_policy.__name__}") async def rollout_and_maybe_produce_result( problem: dict, @@ -144,13 +161,16 @@ async def rollout_and_maybe_produce_result( llm_index: int, session: aiohttp.ClientSession, ): - nonlocal started_rollouts, finished_rollouts + nonlocal started_rollouts, finished_rollouts, token_count try: llm = llms[llm_index] model_version = trainer_state.propagated_weight_version assert model_version is not None - rollout_result = await rollout_policy(cfg, llm, problem, session) + logger.info(f"Starting rollout policy for problem {problem['id']}") + rollout_result: RolloutResult = await rollout_policy(cfg, llm, problem, session) + logger.info(f"Finished rollout policy for problem {problem['id']}") rollout_result.model_version = model_version + token_count += get_number_of_tokens_in_result(rollout_result) # Make a group id that will be different from groups made by another rollout maker full_group_id = f"{scheduler_name}_{group_id}" rollout_result.group_id = full_group_id @@ -187,15 +207,21 @@ async def rollout_and_maybe_produce_result( logger.info("Starting rollout scheduler") connector = aiohttp.TCPConnector(limit=50000, limit_per_host=50000, keepalive_timeout=1.0) timeout = aiohttp.ClientTimeout(total=3600.0, connect=3600.0, sock_read=3600.0) + old_finished_rollouts = 0 + start_time = time.time() async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session: while True: if time.time() - last_logged > 10.0 and sum(active_rollouts): + if finished_rollouts > old_finished_rollouts: + old_finished_rollouts = finished_rollouts + save_debug_line({"rollouts_finished": finished_rollouts, "tokens_produced": token_count, "dt": time.time() - start_time, "token_speed": token_count / (time.time() - start_time)}) logger.info( f"{scheduler_name}: " f"rollouts in progress: {sum(active_rollouts)}, " f"groups in progress: {len(group_rollouts)}, " f"rollouts started so far: {started_rollouts}, " f"rollouts finished so far: {finished_rollouts}, " + f"groups started so far: {group_id}, " f"max group size in bytes: {result_queue.max_actual_entry_size()}, " ) last_logged = time.time() @@ -217,7 +243,6 @@ async def rollout_and_maybe_produce_result( await 
asyncio.sleep(0.01) continue active_rollouts[next_llm] += 1 - started_rollouts += 1 assert problem is not None loop.create_task( rollout_and_maybe_produce_result( @@ -228,6 +253,7 @@ async def rollout_and_maybe_produce_result( session=session, ) ) + started_rollouts += 1 group_rollout_index += 1 logger.info("Rollout scheduler finished") @@ -281,40 +307,45 @@ def __init__( self.sliding_aggregator = SlidingWindowAggregator(window_size=cfg.actor.throughput_window_size) self.llms = llms self.loop_start_time = -1 - self.cfg = cfg + self.cfg: DictConfig = cfg self.is_training = is_training self.is_scheduling_paused = False self.debug_mode = bool(cfg.debug.mode) + self.cfg: DictConfig = cfg - # Determine the number of processes to use - num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) - attempts = self.cfg.attempts if is_training else 1 - - # Divide LLMs approximately equally across processes - llm_groups = [[] for _ in range(num_processes)] - for i, llm in enumerate(self.llms): - llm_groups[i % num_processes].append((i, llm)) + self.smm: SharedMemoryManager | None = None + self.problem_queue: SharedMemoryQueue | None = None + self.result_queue: SharedMemoryQueue | None = None + logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop") + def start_backend(self): self.smm = SharedMemoryManager() self.smm.start() - # Use SharedMemoryQueue instead of separate problem_queue, result_queue, and io_buffer - self.problem_queue = SharedMemoryQueue(self.smm, self.cfg.actor.problem_queue_size, cfg.actor.shared_memory_entry_size) - self.result_queue = SharedMemoryQueue(self.smm, self.cfg.actor.result_queue_size, cfg.actor.shared_memory_entry_size) - - logger.info(f"Initialized {'train' if self.is_training else 'test'} actor loop") + self.problem_queue = SharedMemoryQueue(self.smm, self.cfg.actor.problem_queue_size, self.cfg.actor.shared_memory_entry_size) + self.result_queue = SharedMemoryQueue(self.smm, self.cfg.actor.result_queue_size, self.cfg.actor.shared_memory_entry_size) + logger.info(f"Problem queue size: {self.problem_queue.max_size}, result queue size: {self.result_queue.max_size}") logger.info(f"Result queue buffer size: {self.result_queue.get_memory_size() / 2**30} Gb") # Create and start multiple rollout processes + attempts = self.cfg.attempts if self.is_training else 1 + # Determine the number of processes to use + num_processes = min(self.cfg.actor.rollout_workers, len(self.llms)) + + # Divide LLMs approximately equally across processes + llm_groups = [[] for _ in range(num_processes)] + for i, llm in enumerate(self.llms): + llm_groups[i % num_processes].append((i, llm)) + self.rollout_processes = [] for llm_group in llm_groups: assert llm_group llm_idxs = [llm[0] for llm in llm_group] llms = [llm[1] for llm in llm_group] scheduler_name = ( - f"{'train' if is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}" + f"{'train' if self.is_training else 'test'} scheduler for llms {','.join([str(i) for i in llm_idxs])}" ) process = mp.Process( target=rollout_maker_entrypoint, @@ -328,15 +359,15 @@ def init_stats(self): self.latency_list = [] self.model_versions_list = [] self.sliding_stats = defaultdict(list) - + def compute_domain_agnostic_metrics(self, result: RolloutResult) -> Dict[str, float]: metrics = {} - + metrics['overflow'] = all([not training_text.finished for training_text in result.training_texts ]) metrics['num_turns'] = len(result.training_texts) metrics['prompt_tokens'] = [training_text.prompt_tokens for 
training_text in result.training_texts] metrics['output_tokens'] = [training_text.output_tokens for training_text in result.training_texts] - + return metrics def update_stats(self, rollout_results: List[RolloutResult]): @@ -347,8 +378,10 @@ def update_stats(self, rollout_results: List[RolloutResult]): group_id = result.group_id self.latency_list.append(result.latency) self.model_versions_list.append(result.model_version) - domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) + domain_agnostic_metrics = self.compute_domain_agnostic_metrics(result) all_metrics = result.metrics.model_dump() | domain_agnostic_metrics + all_metrics["used_python"] = int(all_metrics.get("used_python", False)) + all_metrics["used_math_answer"] = int(all_metrics.get("used_math_answer", False)) for k, v in all_metrics.items(): if isinstance(v, list): self.stats[k][dataset_name][group_id] += v @@ -356,7 +389,7 @@ def update_stats(self, rollout_results: List[RolloutResult]): self.stats[k][dataset_name][group_id].append(v) else: raise ValueError(f"Unsupported metric type: {type(v)} for key {k}") - + prompt_length_tokens = [training_text.prompt_tokens for result in rollout_results for training_text in result.training_texts] output_length_tokens = [training_text.output_tokens for result in rollout_results for training_text in result.training_texts] self.sliding_aggregator.update(prompt_length_tokens, output_length_tokens) @@ -364,7 +397,7 @@ def update_stats(self, rollout_results: List[RolloutResult]): if sliding_window_stats is not None: for k, v in sliding_window_stats.items(): self.sliding_stats[k].append(v) - + def run(self, dataset: list[tuple[str, dict]]): @@ -437,13 +470,13 @@ def run(self, dataset: list[tuple[str, dict]]): if not self.is_scheduling_paused: while True: blocked_by_lag = submitted_groups == can_submit_before_update and self.is_training - if not blocked_by_lag and not self.problem_queue.full(): + if not blocked_by_lag and self.have_capacity(): try: try: problem = next(problem_iter) - self.problem_queue.put(problem, block=False) + self.submit_problem(problem) submitted_groups += 1 - except queue.Full: + except queue.Full: assert False, "Problem queue was not full just a moment ago, but now it is full" except StopIteration: break @@ -453,7 +486,7 @@ def run(self, dataset: list[tuple[str, dict]]): # Second, try return a result try: # Directly get the result from the SharedMemoryQueue - rollout_results = self.result_queue.get(block=False) + rollout_results = self.get_new_results() except queue.Empty: continue @@ -462,11 +495,16 @@ def run(self, dataset: list[tuple[str, dict]]): raise rollout_results assert isinstance(rollout_results, list) + if len(rollout_results) == 0: + continue assert isinstance(rollout_results[0], RolloutResult) + assert len(rollout_results) == attempts, ( + f"Expected {attempts} rollouts, got {len(rollout_results)}" + ) group_samples = sum(len(r.training_texts) for r in rollout_results) published_samples += group_samples - samples_in_queue = self.result_queue.qsize() * attempts + samples_in_queue = self.results_ready_to_publish() all_text_dumps = [] for r in rollout_results: for text in r.training_texts: @@ -479,14 +517,13 @@ def run(self, dataset: list[tuple[str, dict]]): f" {in_progress} groups in progress" ) - self.update_stats(rollout_results=rollout_results) finished_groups += 1 time_to_publish_train_stats = ( self.is_training and trainer_version_to_publish is not None - ) or self.debug_mode + ) or self.debug_mode time_to_publish_test_stats = finished_groups 
== expected_rollouts # Publish stats at every new model version or if all tapes are finished @@ -494,11 +531,12 @@ def run(self, dataset: list[tuple[str, dict]]): if self.is_training: loop_stats = { "published_samples": published_samples, - "problem_queue_size": self.problem_queue.qsize(), - "result_queue_size": self.result_queue.qsize(), + "problem_queue_size": self.problem_queue_size(), + "result_queue_size": self.result_queue_size(), "finished_groups": finished_groups, - "trainer_model_version": trainer_version_to_publish, + "trainer_model_version": trainer_version_to_publish, "time_since_start": time.time() - loop_start_time, + "groups_in_progress": in_progress, } trainer_version_to_publish = None else: @@ -514,6 +552,7 @@ def run(self, dataset: list[tuple[str, dict]]): if finished_groups == expected_rollouts: logger.info(f"Finished {expected_rollouts} rollouts, stopping actor loop") + self.stop_tasks() break def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): @@ -546,14 +585,198 @@ def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): stats |= loop_stats for k, v in self.sliding_stats.items(): stats[k] = sum(v) / len(v) if v else 0 + + rename_suffixes = { + "num_python_calls_mean": "python_calls_mean", + "used_python_mean": "python_usage_rate", + "num_math_answer_calls_mean": "math_answer_calls_mean", + "used_math_answer_mean": "math_answer_usage_rate", + } + + for key in list(stats.keys()): + for old_suffix, new_suffix in rename_suffixes.items(): + if key.endswith(old_suffix): + prefix = key[: -len(old_suffix)] + stats[f"{prefix}{new_suffix}"] = stats[key] + break + + logger.info(f"Publish actor stats to wandb: {stats}") if self.cfg.wandb.use_wandb: wandb.log({f"actor/{k}": v for k, v in stats.items()}) stats_writer.write(stats) self.init_stats() # Reset stats for the next iteration + def have_capacity(self) -> bool: + return not self.problem_queue.full() + + def submit_problem(self, problem: dict): + self.problem_queue.put(problem, block=False) + + def stop_tasks(self): + pass + + def get_new_results(self) -> list[RolloutResult]: + return self.result_queue.get(block=False) + + def results_ready_to_publish(self) -> int: + return self.result_queue_size() * self.cfg.attempts + + def problem_queue_size(self) -> int: + return self.problem_queue.qsize() + + def result_queue_size(self) -> int: + return self.result_queue.qsize() + + +class ActorLoopRay(ActorLoop): + """ + Loop that runs the ray tasks for n_jobs to perform rollouts in parallel + """ + ray_ready: bool = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cfg_dict = OmegaConf.to_container(self.cfg, resolve=True) + self.unfinished_tasks = [] + self.llms_by_url = {llm.get_base_url(): llm for llm in self.llms} + self.llms_utilization = {llm.get_base_url(): 0 for llm in self.llms} + self.scheduler_name = f"{'train' if self.is_training else 'test'} ray scheduler" + self.problem_id = 0 + self.attempts = self.cfg.attempts if self.is_training else 1 + self.unfinished_problems = defaultdict(list) # up to `attempts` rollout results for each problem + self.finished_problems = [] + self.token_count = 0 + self.finished_rollouts_count = 0 + self.task_latencies = [] + self.ray_result_latencies = [] + + def start_backend(self): + if not self.ray_ready: + logger.info(f"Initializing Ray with {self.cfg.actor.rollout_workers} workers..") + ray_context = ray.init(num_cpus=self.cfg.actor.rollout_workers, dashboard_host="0.0.0.0", include_dashboard=True) + logger.info(f"Ray 
initialized, dashboard at {ray_context.dashboard_url}") + self.ray_ready = True + else: + logger.info("Ray already initialized") + + assert self.trainer_state.propagated_weight_version is not None + rollout_policy: Callable[[DictConfig, TrainableLLM, dict], RolloutResult] = hydra.utils.get_method(self.cfg.actor.rollout_policy) + def rollout_wrapper(cfg: DictConfig, llm: TrainableLLM, problem: dict, problem_id: int) -> RolloutResult: + start_ts = time.monotonic() + rollout_result: RolloutResult = rollout_policy(cfg, llm, problem) + ts = time.monotonic() + logger.info(f"Problem {problem_id} finished in {ts - start_ts:.2f} seconds") + return rollout_result, llm.get_base_url(), problem_id, ts, start_ts + self.ray_remote = ray.remote(rollout_wrapper) + self.start_time = time.time() + + def have_capacity(self) -> bool: + have_capacity = len(self.unfinished_tasks) < self.cfg.actor.problem_queue_size + have_llm_capacity = any(self.llms_utilization[llm_url] < (self.cfg.actor.llm_max_rollouts - self.attempts) for llm_url in self.llms_utilization) + have_capacity = have_capacity and have_llm_capacity + if not have_capacity: + time.sleep(0.1) # sleep for a while to avoid quick loops when no capacity + return have_capacity + + def submit_problem(self, problem: dict): + for attempt_number in range(self.attempts): + llm_url, task_count = min(self.llms_utilization.items(), key=lambda x: x[1]) + logger.info(f"Submitting problem {self.problem_id} attempt {attempt_number}/{self.attempts} to the least busy LLM {llm_url} with {task_count} tasks") + llm = self.llms_by_url[llm_url] + task_ref = self.ray_remote.remote(self.cfg_dict, llm, problem, self.problem_id) + self.llms_utilization[llm_url] += 1 + self.unfinished_tasks.append(task_ref) + self.problem_id += 1 + + def stop_tasks(self): + ray.shutdown() + + def receive_finished_tasks(self): + num_returns = min(100, len(self.unfinished_tasks)) + try: + finished_tasks, unfinished_tasks = ray.wait(self.unfinished_tasks, num_returns=num_returns, timeout=0.1) + except Exception as e: + logger.error(f"Error waiting for finished ray tasks: {e}") + return + if len(finished_tasks) > 0: + logger.info(f"Found {len(finished_tasks)} finished tasks, {len(unfinished_tasks)} unfinished tasks left") + self.unfinished_tasks = unfinished_tasks + dt = time.time() - self.start_time + for finished_task in finished_tasks: + try: + rollout_result, llm_url, problem_id, stop_ts, start_ts = ray.get(finished_task) + rollout_result.model_version = self.trainer_state.propagated_weight_version + full_group_id = f"{self.scheduler_name}_{problem_id}" + rollout_result.group_id = full_group_id + rollout_index = len(self.unfinished_problems[problem_id]) + for step_index, sample in enumerate(rollout_result.training_texts): + # Downstream in the pipeline we'll need these fields in every sample + sample.metadata["model_version"] = rollout_result.model_version + sample.metadata["rollout_index"] = rollout_index + sample.metadata["step_index"] = step_index + sample.group_id = full_group_id + task_dt = stop_ts - start_ts + self.task_latencies.append(task_dt) + outer_ts = time.monotonic() + ray_result_latency = outer_ts - stop_ts + self.ray_result_latencies.append(ray_result_latency) + except Exception as e: + logger.error(f"Error getting finished ray task: {e}") + continue + if self.llms_utilization[llm_url] > 0: + self.llms_utilization[llm_url] -= 1 + else: + logger.warning(f"LLM {llm_url} utilization is 0, but got a result") + self.token_count += get_number_of_tokens_in_result(rollout_result) + 
self.finished_rollouts_count += 1 + self.unfinished_problems[problem_id].append(rollout_result) + logger.info(f"Problem {problem_id} has {len(self.unfinished_problems[problem_id])} rollout results") + if len(self.unfinished_problems[problem_id]) == self.cfg.attempts: + logger.info(f"Problem {problem_id} group finished") + group = self.unfinished_problems[problem_id] + random.shuffle(group) + self.finished_problems.append(group) + del self.unfinished_problems[problem_id] + logger.info(f"{len(self.finished_problems)} finished problems ready to return") + logger.info( + f"Ray {'train' if self.is_training else 'test'} actor loop: " + f"rollouts in progress: {len(self.unfinished_tasks)}, " + f"problems in progress: {len(self.unfinished_problems)}, " + f"rollouts finished: {self.finished_rollouts_count}, " + f"total tokens: {self.token_count}, " + f"gen speed: {self.token_count / dt:.2f} tokens/sec, " + f"task latency: {np.mean(self.task_latencies[-10:]):.2f} sec, " + f"ray delay: {np.mean(self.ray_result_latencies[-10:]):.4f} sec" + ) + save_debug_line({ + "rollouts_finished": self.finished_rollouts_count, + "rollouts_in_progress": len(self.unfinished_tasks), + "problems_in_progress": len(self.unfinished_problems), + "tokens_produced": self.token_count, + "dt": dt, + "token_speed": self.token_count / dt, + "ray_latency": np.mean(self.ray_result_latencies[-10:]), + "task_latency": np.mean(self.task_latencies[-10:]), + }) + logger.info(f"LLMs utilization: {self.llms_utilization}") + + def get_new_results(self) -> list[list[RolloutResult]]: + self.receive_finished_tasks() + if len(self.finished_problems) > 0: + logger.info(f"have {len(self.finished_problems)} finished problems, pop one") + return self.finished_problems.pop(0) + return [] + + def problem_queue_size(self) -> int: + return len(self.unfinished_tasks) + + def result_queue_size(self) -> int: + return len(self.finished_problems) + def run_actor_loop(cfg: DictConfig): set_streams_backend(**cfg.streams) + actor_loop_class = ActorLoopRay if cfg.use_ray else ActorLoop # set seed for reproducibility (mostly intended for dataset loading) random.seed(cfg.seed) @@ -588,12 +811,19 @@ def run_actor_loop(cfg: DictConfig): actor_model_path = finetune_model_path else: actor_model_path = cfg.model_path - + + # Align client-side context size with vLLM server max_model_len when available + try: + _context_size = int(cfg.vllm_config.vllm_kwargs.max_model_len) + except Exception: + _context_size = 32000 + train_llms = [ TrainableLLM( base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), + context_size=_context_size, parameters=cfg.llm.parameters, use_cache=False, collect_logprobs=True, @@ -606,6 +836,7 @@ def run_actor_loop(cfg: DictConfig): base_url=url, model_name=str(actor_model_path), tokenizer_name=str(actor_model_path), + context_size=_context_size, parameters=cfg.test_llm.parameters, use_cache=False, collect_logprobs=True, @@ -623,13 +854,12 @@ def run_actor_loop(cfg: DictConfig): trainer_state.start_listening() trainer_state.wait_for_model_version() - train_loop = ActorLoop( + train_loop = actor_loop_class( data_stream=data_stream, cfg=cfg, trainer_state=trainer_state, stats_stream=stats_stream, llms=train_llms ) - train_loop_run = train_loop.run( - dataset=train_dataset, - ) - test_loop = ActorLoop( + train_loop.start_backend() + train_loop_run = train_loop.run(dataset=train_dataset) + test_loop = actor_loop_class( data_stream=test_data_stream, cfg=cfg, trainer_state=trainer_state, @@ -658,6 +888,7 @@ def 
run_actor_loop(cfg: DictConfig): and test_loop_run is None ): logger.info("Create test loop") + test_loop.start_backend() test_loop_run = test_loop.run( dataset=test_dataset, ) diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index e375b6a5..aa75d4ed 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -8,12 +8,16 @@ from tapeagents.core import LLMCall, LLMOutput, Prompt, TokenLogprob from tapeagents.llms.trainable import TrainableLLM -from pipelinerl.finetune.data import MASKED_TOKEN_ID -from pipelinerl.rollouts import TrainingText from pipelinerl.processor_factory import get_processor +from pipelinerl.rollouts import TrainingText logger = logging.getLogger(__name__) +# -100 is the default "ignore_index" in nn.CrossEntropyLoss +# Defined here to avoid importing dependencies from finetune.data +# Do not replace. Import from finetune module breaks ray parallelization! +MASKED_TOKEN_ID = -100 + def extract_images_from_messages(messages: list[dict]) -> list[Image.Image]: """Extract PIL Images from multimodal messages.""" diff --git a/pipelinerl/domains/math/__init__.py b/pipelinerl/domains/math/__init__.py index 9aee0b8f..7a9809b7 100644 --- a/pipelinerl/domains/math/__init__.py +++ b/pipelinerl/domains/math/__init__.py @@ -1,3 +1,3 @@ from .load_datasets import load_datasets -from .rollouts import generate_math_rollout, RewardTable +from .rollouts import generate_math_rollout, RewardTable, get_reward, length_penalty from .verifier_api import MathEnvironment, verify_answer, verify_answer_rpc \ No newline at end of file diff --git a/pipelinerl/domains/math/load_datasets.py b/pipelinerl/domains/math/load_datasets.py index 4b44dfb6..7cbf9c18 100644 --- a/pipelinerl/domains/math/load_datasets.py +++ b/pipelinerl/domains/math/load_datasets.py @@ -170,6 +170,26 @@ def _load_aime_dataset(year: int, upsample_factor: int = 0) -> list[dict]: return add_ids(samples) +def _load_aime_2025_opencompass(upsample_factor: int = 0) -> list[dict]: + configs = ["AIME2025-I", "AIME2025-II"] + dataset_name = "aime_2025" + ("" if upsample_factor > 0 else "_original") + + samples: list[dict] = [] + for config_name in configs: + ds = load_dataset("opencompass/AIME2025", config_name, split="test") + samples.extend([s for s in process_math(ds, dataset_name) if s is not None]) + + original_size = len(samples) + if upsample_factor > 0: + samples *= upsample_factor + + logger.info( + f"Loading aime 2025 (OpenCompass) dataset: {len(samples)} samples" + + (f" (upsampled from {original_size})" if upsample_factor > 0 else "") + ) + return add_ids(samples) + + def _load_amc_dataset(year: int, upsample_factor: int = 0) -> list[dict]: amc_dataset = load_dataset("AI-MO/aimo-validation-amc", split="train", trust_remote_code=True) amc_dataset = amc_dataset.filter(lambda x: str(year) in x["url"]) @@ -335,6 +355,12 @@ def load_datasets(dataset_names: List[str] | str | None, seed: int | None = None if "aime_2024_original" in dataset_names: datasets += _load_aime_dataset(2024) + if "aime_2025" in dataset_names: + datasets += _load_aime_2025_opencompass(upsample_factor=16) + + if "aime_2025_original" in dataset_names: + datasets += _load_aime_2025_opencompass() + if "amc_2022" in dataset_names: # TODO: AMC 2022 is 43 problems, is that to be expected? 
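
(Aside on the `MASKED_TOKEN_ID = -100` constant re-declared in `async_llm.py` above: `-100` is the default `ignore_index` of `nn.CrossEntropyLoss`, so labels set to it are skipped entirely. A minimal sketch:)

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                # 4 token positions, vocab of 10
labels = torch.tensor([3, -100, 7, -100])  # two masked positions
# Loss is averaged over the two unmasked positions only; masked
# positions contribute neither loss nor gradient.
loss = F.cross_entropy(logits, labels)
```
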
datasets += _load_amc_dataset(2022, upsample_factor=16)
diff --git a/pipelinerl/domains/math/rollouts.py b/pipelinerl/domains/math/rollouts.py
index 41a61021..c293b36f 100644
--- a/pipelinerl/domains/math/rollouts.py
+++ b/pipelinerl/domains/math/rollouts.py
@@ -26,6 +26,28 @@ class RewardTable(BaseModel):
     correct_answer_finished: float
     buffer_tokens: int = 0 # 0 means no overlong reward shaping
 
+def get_reward(answer_status: str, finished: bool, reward_table: RewardTable) -> float:
+    match (answer_status, finished):
+        case ("wrong", False):
+            return reward_table.wrong_answer_not_finished
+        case ("wrong", True):
+            return reward_table.wrong_answer_finished
+        case ("no_answer", False):
+            return reward_table.no_answer_not_finished
+        case ("no_answer", True):
+            return reward_table.no_answer_finished
+        case ("unparsable", False):
+            return reward_table.unparsable_not_finished
+        case ("unparsable", True):
+            return reward_table.unparsable_finished
+        case ("correct", False):
+            return reward_table.correct_answer_not_finished
+        case ("correct", True):
+            return reward_table.correct_answer_finished
+        case _:
+            raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{finished}")
+
+
 def length_penalty(max_length: int, sequence_length: int, buffer_tokens: int) -> float:
     """
     Compute the overlong penalty
@@ -51,7 +73,7 @@ async def generate_math_rollout(
     latency = time.time() - time_start
 
     assert llm_call.output.content is not None
-    rewards = RewardTable(**dict(cfg.rewards))
+    reward_table = RewardTable(**dict(cfg.rewards))
     discount_factor = cfg.actor.discount_factor
 
     # math_verify is a fast environment, no support for environment replicas for now
@@ -70,30 +92,11 @@
     trace = make_training_text(llm, llm_call)
 
     # Determine reward based on answer status and finished state
-    match (answer_status, trace.finished):
-        case ("wrong", False):
-            reward = rewards.wrong_answer_not_finished
-        case ("wrong", True):
-            reward = rewards.wrong_answer_finished
-        case ("no_answer", False):
-            reward = rewards.no_answer_not_finished
-        case ("no_answer", True):
-            reward = rewards.no_answer_finished
-        case ("unparsable", False):
-            reward = rewards.unparsable_not_finished
-        case ("unparsable", True):
-            reward = rewards.unparsable_finished
-        case ("correct", False):
-            reward = rewards.correct_answer_not_finished
-        case ("correct", True):
-            reward = rewards.correct_answer_finished
-        case _:
-            raise ValueError(f"Invalid answer_status/finished combination: {answer_status}/{trace.finished}")
-
+    reward = get_reward(answer_status, trace.finished, reward_table)
     # Apply discount factor based on output length
     reward *= discount_factor**llm_call.output_length_tokens
     overlong_penalty = 0
-    if rewards.buffer_tokens > 0:
-        overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, rewards.buffer_tokens)
+    if reward_table.buffer_tokens > 0:
+        overlong_penalty = length_penalty(llm.parameters['max_tokens'], llm_call.output_length_tokens, reward_table.buffer_tokens)
     reward += overlong_penalty
     trace.reward = reward
diff --git a/pipelinerl/domains/mcp/__init__.py b/pipelinerl/domains/mcp/__init__.py
new file mode 100644
index 00000000..4557fa53
--- /dev/null
+++ b/pipelinerl/domains/mcp/__init__.py
@@ -0,0 +1,2 @@
+from .env_server import EmbeddedEnvironmentWorker, EmbeddedMCPEnvironment, MCPEnvironmentServer
+from .rollouts import generate_mcp_rollout, generate_mcp_rollout_with_local_env
diff --git a/pipelinerl/domains/mcp/env_server.py b/pipelinerl/domains/mcp/env_server.py
new file mode 100644
index 00000000..2298e5cd
--- /dev/null
+++ b/pipelinerl/domains/mcp/env_server.py
@@ -0,0
+1,1035 @@ +import asyncio +import atexit +import inspect +import json +import logging +import os +import re +import threading +import time +import traceback +from concurrent.futures import ProcessPoolExecutor +from contextlib import asynccontextmanager +from functools import partial +from typing import Any, AsyncIterator, List + +import multiprocessing + +from fastapi import HTTPException +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from pydantic import BaseModel +from tapeagents.core import Action, Observation +from tapeagents.environment import Environment +from tapeagents.mcp import MCPClient, MCPEnvironment, NoTool +from tapeagents.remote_environment import EnvironmentServer +from tapeagents.tool_calling import FunctionSpec, ToolCallAction, ToolResult, ToolSpec +from mcp.types import CallToolResult, TextContent + +from pipelinerl.domains.math.verifier_api import verify_answer +from pipelinerl.domains.mcp.steps import MathAnswer + +logger = logging.getLogger(__name__) + + +_CONNECTION_ERROR_PATTERNS = ( + "closedresourceerror", + "brokenresourceerror", + "broken pipe", + "connectionreseterror", + "timed out while waiting for response", +) + + +_MCP_WORKER_STATE: dict[str, Any] | None = None + + +def _shutdown_mcp_worker() -> None: + global _MCP_WORKER_STATE + if not _MCP_WORKER_STATE: + return + loop: asyncio.AbstractEventLoop = _MCP_WORKER_STATE["loop"] + client: MCPClient = _MCP_WORKER_STATE["client"] + try: + loop.run_until_complete(client.close()) + except Exception: + logger.warning("Failed to close MCP client in worker", exc_info=True) + finally: + loop.close() + _MCP_WORKER_STATE = None + + +def _initialize_mcp_worker( + config_path: str, + tools_whitelist: list[str] | tuple[str, ...] | None, + use_cache: bool, + read_timeout_seconds: int, +) -> None: + """Initializer for the ProcessPool workers that own MCP runtimes.""" + global _MCP_WORKER_STATE + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = MCPClient( + config_path=config_path, + use_cache=use_cache, + read_timeout_seconds=read_timeout_seconds, + ) + loop.run_until_complete(client.start_servers()) + _MCP_WORKER_STATE = { + "loop": loop, + "client": client, + "tools_whitelist": list(tools_whitelist or []), + } + atexit.register(_shutdown_mcp_worker) + + +def _call_tool_in_worker(tool_name: str, tool_arguments: Any) -> dict[str, Any]: + """Execute an MCP tool call inside a worker process.""" + if not _MCP_WORKER_STATE: + raise RuntimeError("MCP worker not initialized") + loop: asyncio.AbstractEventLoop = _MCP_WORKER_STATE["loop"] + client: MCPClient = _MCP_WORKER_STATE["client"] + whitelist: list[str] = _MCP_WORKER_STATE.get("tools_whitelist", []) + if whitelist and tool_name not in whitelist: + raise NoTool(f"Tool {tool_name} not allowed by whitelist") + result = loop.run_until_complete(client.call_tool(tool_name, tool_arguments)) + return result.model_dump(exclude_none=True) + + +class _RemoteCallError(RuntimeError): + def __init__(self, message: str, details: dict[str, Any] | None = None) -> None: + super().__init__(message) + self.details = details or {} + + +def _invoke_environment_method( + environment: Environment, + method_name: str, + args: tuple[Any, ...], + kwargs: dict[str, Any], + loop: asyncio.AbstractEventLoop, +) -> Any: + attr = getattr(environment, method_name) + if inspect.iscoroutinefunction(attr): + return loop.run_until_complete(attr(*args, **kwargs)) + result = attr(*args, **kwargs) + if inspect.isawaitable(result): + return 
loop.run_until_complete(result) + return result + + +def _environment_process_main(env_cfg_container: dict[str, Any], conn) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + env_cfg = OmegaConf.create(env_cfg_container) + environment: Environment = instantiate(env_cfg) + except Exception: + conn.send( + ( + "exception", + { + "type": "EnvironmentBootstrapError", + "message": "Failed to instantiate environment", + "traceback": traceback.format_exc(), + }, + ) + ) + conn.close() + loop.close() + return + + async_methods = { + name + for name in ("ainitialize", "areset", "aclose", "astep", "areact") + if hasattr(environment, name) and inspect.iscoroutinefunction(getattr(environment, name)) + } + sync_methods = { + name + for name in ( + "initialize", + "reset", + "close", + "start_task", + "actions", + "tools_description", + "mark_healthy", + "is_healthy", + "step", + "react", + ) + if callable(getattr(environment, name, None)) + } + + conn.send(("capabilities", {"sync": list(sync_methods), "async": list(async_methods)})) + + running = True + while running: + try: + message = conn.recv() + except EOFError: + break + if not isinstance(message, tuple) or len(message) != 3: + continue + command, args, kwargs = message + if command == "__shutdown__": + running = False + conn.send(("ok", None)) + break + try: + result = _invoke_environment_method(environment, command, args, kwargs, loop) + conn.send(("ok", result)) + except Exception as exc: + conn.send( + ( + "exception", + { + "type": exc.__class__.__name__, + "message": str(exc), + "traceback": traceback.format_exc(), + }, + ) + ) + + try: + if "aclose" in async_methods: + loop.run_until_complete(environment.aclose()) + elif "close" in sync_methods: + environment.close() + except Exception: + logger.debug("Failed to close environment during shutdown", exc_info=True) + finally: + conn.close() + loop.close() + + +class _ProcessEnvironmentProxy: + def __init__(self, env_cfg: DictConfig): + self._ctx = multiprocessing.get_context("spawn") + self._parent_conn, child_conn = self._ctx.Pipe() + cfg_container = OmegaConf.to_container(env_cfg, resolve=True) + self._process = self._ctx.Process( + target=_environment_process_main, + args=(cfg_container, child_conn), + ) + self._process.daemon = False + self._process.start() + self._lock = threading.Lock() + self._closed = False + try: + status, payload = self._parent_conn.recv() + except EOFError as error: + raise _RemoteCallError("Environment process terminated prematurely") from error + if status == "exception": + raise _RemoteCallError(payload.get("message", "Environment bootstrap failed"), payload) + if status != "capabilities": + raise _RemoteCallError("Unexpected handshake from environment process") + self._sync_methods = set(payload.get("sync", [])) + self._async_methods = set(payload.get("async", [])) + + def supports_async(self, name: str) -> bool: + return name in self._async_methods + + def supports_sync(self, name: str) -> bool: + return name in self._sync_methods + + def _ensure_alive(self) -> None: + if self._closed: + raise _RemoteCallError("Environment proxy is closed") + if not self._process.is_alive(): + raise _RemoteCallError("Environment process died unexpectedly") + + def _call_remote(self, method: str, *args: Any, **kwargs: Any) -> Any: + self._ensure_alive() + with self._lock: + try: + self._parent_conn.send((method, args, kwargs)) + status, payload = self._parent_conn.recv() + except EOFError as error: + raise _RemoteCallError("Lost connection to 
environment process") from error + if status == "ok": + return payload + if status == "exception": + raise _RemoteCallError(payload.get("message", "Remote call failed"), payload) + raise _RemoteCallError(f"Unexpected response type: {status}") + + def start_task(self, task: dict) -> dict: + return self._call_remote("start_task", task) + + def actions(self) -> tuple[type[Action], ...]: + return tuple(self._call_remote("actions")) + + def tools_description(self) -> str: + return self._call_remote("tools_description") + + def initialize(self): + if self.supports_sync("initialize"): + return self._call_remote("initialize") + if self.supports_async("ainitialize"): + return self._call_remote("ainitialize") + return None + + async def ainitialize(self) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.initialize) + + def reset(self) -> None: + if self.supports_sync("reset"): + self._call_remote("reset") + elif self.supports_async("areset"): + self._call_remote("areset") + + async def areset(self) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.reset) + + def step(self, action: Action) -> Observation: + if self.supports_sync("step"): + return self._call_remote("step", action) + if self.supports_async("astep"): + return self._call_remote("astep", action) + raise _RemoteCallError("Remote environment does not support step or astep") + + async def astep(self, action: Action) -> Observation: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.step, action) + + def react(self, tape) -> Any: + if self.supports_sync("react"): + return self._call_remote("react", tape) + if self.supports_async("areact"): + return self._call_remote("areact", tape) + raise _RemoteCallError("Remote environment does not support react or areact") + + async def areact(self, tape) -> Any: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.react, tape) + + def mark_healthy(self) -> None: + if self.supports_sync("mark_healthy"): + self._call_remote("mark_healthy") + + def is_healthy(self) -> bool: + if self.supports_sync("is_healthy"): + return bool(self._call_remote("is_healthy")) + return True + + def close(self) -> None: + if self._closed: + return + try: + if self.supports_sync("close"): + self._call_remote("close") + elif self.supports_async("aclose"): + self._call_remote("aclose") + except _RemoteCallError: + logger.debug("Remote close failed", exc_info=True) + finally: + self._shutdown() + + async def aclose(self) -> None: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.close) + + def _shutdown(self) -> None: + if self._closed: + return + try: + with self._lock: + if self._process.is_alive(): + self._parent_conn.send(("__shutdown__", (), {})) + try: + self._parent_conn.recv() + except EOFError: + pass + except Exception: + logger.debug("Failed to send shutdown to environment process", exc_info=True) + finally: + self._parent_conn.close() + self._process.join(timeout=5) + if self._process.is_alive(): + self._process.terminate() + self._closed = True + + def __del__(self) -> None: + try: + self._shutdown() + except Exception: + pass +class EnvironmentServerWithVerifier(EnvironmentServer): + """Environment server that includes the verify_answer endpoint.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.process_pool = ProcessPoolExecutor(max_workers=4) + + def create_app(self): + app = super().create_app() + + class 
VerifyAnswerRequest(BaseModel): + prediction: str + gold: str + strict: bool = True + max_prediction_length: int = 1000 + + @app.post("/verify_answer") + async def verify_answer_endpoint(request: VerifyAnswerRequest): + try: + # Run verification in the process pool to avoid blocking the main thread + loop = asyncio.get_event_loop() + answer_status = await loop.run_in_executor( + self.process_pool, + partial( + verify_answer, + request.prediction, + request.gold, + request.strict, + request.max_prediction_length + ) + ) + return {"answer_status": answer_status} + except Exception as e: + logger.exception(f"Error in verify_answer: {e}") + raise HTTPException(status_code=500, detail=f"Error verifying answer: {str(e)}") + + return app + + def shutdown(self): + super().shutdown() + if hasattr(self, 'process_pool'): + self.process_pool.shutdown(wait=True) + + +class MCPEnvironmentServer: + + def __init__(self, + n_envs: int, + host: str, + mcp_target: str, + mcp_config_path: str, + mcp_tools_whitelist: List[str], + exp_path: str, + env_call_timeout: int = 60, + mcp_read_timeout_seconds: int = 10, + ): + # Remote environment server configuration + self.n_envs = n_envs + self.host = host + self.env_call_timeout = env_call_timeout + # Individual web environment configuration + self.mcp_target = mcp_target + self.mcp_config_path = mcp_config_path + self.mcp_tools_whitelist = mcp_tools_whitelist + self.exp_path = exp_path + self.mcp_read_timeout_seconds = mcp_read_timeout_seconds + + + def launch(self, port: int): + """ + Serve the environment in TapeAgent with verify_answer endpoint. + """ + env_server = EnvironmentServerWithVerifier( + n_envs=self.n_envs, + host=self.host, + port=port, + env_call_timeout=self.env_call_timeout + ) + env_server.launch(OmegaConf.create({ + "_target_": self.mcp_target, + "config_path": self.mcp_config_path, + "tools_whitelist": self.mcp_tools_whitelist, + "read_timeout_seconds": self.mcp_read_timeout_seconds, + })) + + +class EmbeddedMCPEnvironment(MCPEnvironment): + def __init__( + self, + *args, + math_answer_description: str = "Submit the final answer in LaTeX \\boxed{} format.", + **kwargs, + ) -> None: + config_path = kwargs.get("config_path", "") + use_cache = kwargs.get("use_cache", False) + read_timeout_seconds = kwargs.get("read_timeout_seconds", 10) + runtime_pool_workers = kwargs.pop("runtime_pool_workers", 0) + offload_tools = tuple(kwargs.pop("offload_tools", ())) + + super().__init__(*args, **kwargs) + self._broken = False + self._last_failure_reason: str | None = None + self._runtime_guard_installed: bool = False + self._runtime_pool: ProcessPoolExecutor | None = None + self._runtime_pool_lock = threading.Lock() + self._runtime_pool_workers = runtime_pool_workers + self._offload_tools = set(offload_tools) + self._config_path = getattr(self.client, "config_path", config_path) + self._use_cache = getattr(self.client, "use_cache", use_cache) + self._read_timeout_seconds = getattr(self.client, "read_timeout_seconds", read_timeout_seconds) + + # try to catch time wasting patterns before execution + self._python_blocklist = ( + (re.compile(r"\bsys\s*\.\s*exit\s*\(", re.IGNORECASE), "sys.exit"), + (re.compile(r"\bos\s*\.\s*_exit\s*\(", re.IGNORECASE), "os._exit"), + (re.compile(r"\bexit\s*\(", re.IGNORECASE), "exit"), + (re.compile(r"\bquit\s*\(", re.IGNORECASE), "quit"), + (re.compile(r"raise\s+systemexit", re.IGNORECASE), "raise SystemExit"), + (re.compile(r"from\s+sys\s+import\s+exit", re.IGNORECASE), "from sys import exit"), + ( + 
re.compile(r"__import__\s*\(\s*['\"]os['\"]\s*\)\s*\.\s*_exit", re.IGNORECASE), + "__import__('os')._exit", + ), + ( + re.compile(r"__import__\s*\(\s*['\"]sys['\"]\s*\)\s*\.\s*exit", re.IGNORECASE), + "__import__('sys').exit", + ), + ) + self._math_answer_spec = ToolSpec( + function=FunctionSpec( + name="MathAnswer", + description=math_answer_description, + parameters={ + "type": "object", + "properties": { + "answer": { + "type": "string", + "description": "Final answer expressed in LaTeX \\boxed{} format.", + } + }, + "required": ["answer"], + }, + ) + ) + + def initialize(self): + super().initialize() + self._reset_health() + self._ensure_math_answer_tool() + + async def ainitialize(self) -> None: + self.loop = asyncio.get_running_loop() + await super().ainitialize() + self._reset_health() + self._ensure_math_answer_tool() + await self._install_runtime_guard() + + def actions(self): + base_actions = super().actions() + if not any( + getattr(action, "function", None) and action.function.name == "MathAnswer" + for action in base_actions + ): + base_actions = base_actions + (self._math_answer_spec,) + return base_actions + + def _should_offload(self, tool_name: str) -> bool: + return bool(self._runtime_pool_workers) and tool_name in self._offload_tools + + def _ensure_runtime_pool(self) -> ProcessPoolExecutor: + if self._runtime_pool is not None: + return self._runtime_pool + with self._runtime_pool_lock: + if self._runtime_pool is not None: + return self._runtime_pool + cpu_count = os.cpu_count() or 1 + default_workers = max(1, cpu_count // 2) + max_workers = self._runtime_pool_workers or default_workers + whitelist = tuple(self.tools_whitelist) if getattr(self, "tools_whitelist", None) else tuple() + self._runtime_pool = ProcessPoolExecutor( + max_workers=max_workers, + initializer=_initialize_mcp_worker, + initargs=( + self._config_path, + whitelist, + bool(self._use_cache), + int(self._read_timeout_seconds), + ), + ) + return self._runtime_pool + + @staticmethod + def _make_error_call_result(tool_name: str, message: str) -> CallToolResult: + return CallToolResult( + content=[TextContent(type="text", text=message)], + isError=True, + ) + + def _resolve_pool_future_sync(self, future, tool_name: str) -> CallToolResult: + try: + payload = future.result() + return CallToolResult.model_validate(payload) + except NoTool: + logger.exception(f"Tool {tool_name} not found in MCP client") + return self._make_error_call_result(tool_name, f"Tool {tool_name} not found") + except KeyError as error: + logger.exception(f"KeyError when executing MCP tool call: {error}") + return self._make_error_call_result( + tool_name, f"Error executing tool {tool_name}: KeyError {error}" + ) + except Exception as error: + logger.exception(f"Error executing MCP tool call: {error}") + return self._make_error_call_result( + tool_name, f"Error executing tool {tool_name}: {error}" + ) + + async def _resolve_pool_future_async(self, future, tool_name: str) -> CallToolResult: + try: + payload = await asyncio.wrap_future(future) + return CallToolResult.model_validate(payload) + except NoTool: + logger.exception(f"Tool {tool_name} not found in MCP client") + return self._make_error_call_result(tool_name, f"Tool {tool_name} not found") + except KeyError as error: + logger.exception(f"KeyError when executing MCP tool call: {error}") + return self._make_error_call_result( + tool_name, f"Error executing tool {tool_name}: KeyError {error}" + ) + except Exception as error: + logger.exception(f"Error executing MCP tool call: {error}") 
+ return self._make_error_call_result( + tool_name, f"Error executing tool {tool_name}: {error}" + ) + + def _shutdown_runtime_pool(self) -> None: + if self._runtime_pool is not None: + self._runtime_pool.shutdown(wait=True) + self._runtime_pool = None + + def _execute_tool_via_pool_sync(self, action: ToolCallAction) -> ToolResult: + start = time.perf_counter() + future = self._ensure_runtime_pool().submit( + _call_tool_in_worker, + action.function.name, + action.function.arguments, + ) + call_result = self._resolve_pool_future_sync(future, action.function.name) + observation = ToolResult(tool_call_id=getattr(action, "id", ""), content=call_result) + observation.metadata.other["action_execution_time"] = time.perf_counter() - start + observation.metadata.other["action_kind"] = action.kind + return observation + + async def _execute_tool_via_pool_async(self, action: ToolCallAction) -> ToolResult: + start = time.perf_counter() + future = self._ensure_runtime_pool().submit( + _call_tool_in_worker, + action.function.name, + action.function.arguments, + ) + call_result = await self._resolve_pool_future_async(future, action.function.name) + observation = ToolResult(tool_call_id=getattr(action, "id", ""), content=call_result) + observation.metadata.other["action_execution_time"] = time.perf_counter() - start + observation.metadata.other["action_kind"] = action.kind + return observation + + def step(self, action: Action) -> Observation: + if not isinstance(action, ToolCallAction): + return super().step(action) + + outcome, message = self._precheck_tool_action(action) + if outcome == "math_answer": + return self._create_math_answer(action) + if outcome == "error": + return self._make_error_tool_result(action, message or "") + + try: + observation = self._execute_tool_call_sync(action) + except BaseException: + self._broken = True + raise + + return self._postprocess_after_tool(action, observation) + + async def astep(self, action: Action) -> Observation: + if not isinstance(action, ToolCallAction): + return await super().astep(action) + + outcome, message = self._precheck_tool_action(action) + if outcome == "math_answer": + return self._create_math_answer(action) + if outcome == "error": + return self._make_error_tool_result(action, message or "") + + try: + observation = await self._execute_tool_call_async(action) + except BaseException: + self._broken = True + raise + + return self._postprocess_after_tool(action, observation) + + def _precheck_tool_action(self, action: ToolCallAction) -> tuple[str, str | None]: + if action.function.name == "MathAnswer": + return "math_answer", None + if self._broken: + return "error", self._backend_unavailable_message() + if action.function.name == "run_python_code": + block_message = self._check_python_safety(action.function.arguments) + if block_message is not None: + return "error", block_message + return "ok", None + + def _execute_tool_call_sync(self, action: ToolCallAction) -> Observation: + if self._should_offload(action.function.name): + return self._execute_tool_via_pool_sync(action) + return super().step(action) + + async def _execute_tool_call_async(self, action: ToolCallAction) -> Observation: + if self._should_offload(action.function.name): + return await self._execute_tool_via_pool_async(action) + return await super().astep(action) + + def _postprocess_after_tool( + self, + action: ToolCallAction, + observation: Observation, + ) -> Observation: + if action.function.name != "MathAnswer": + return self._postprocess_tool_observation(action, observation) 
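+        # MathAnswer results are synthesized locally and never touch the MCP
+        # backend, so they skip the connection-failure post-processing.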
+ return observation + + def _ensure_math_answer_tool(self) -> None: + if not any( + getattr(tool, "function", None) and tool.function.name == "MathAnswer" + for tool in self.tools + ): + self.tools.append(self._math_answer_spec) + + def _reset_health(self) -> None: + self._broken = False + self._last_failure_reason = None + self._runtime_guard_installed = False + + def _create_math_answer(self, action: ToolCallAction) -> MathAnswer: + answer_value = self._extract_answer(action.function.arguments) + math_answer = MathAnswer(answer=answer_value) + math_answer.metadata.other.update({ + "action_kind": "MathAnswer", + "tool_call_id": getattr(action, "id", ""), + "action_execution_time": 0.0, + }) + return math_answer + + def mark_healthy(self) -> None: + self._reset_health() + + def is_healthy(self) -> bool: + return not self._broken + + def close(self) -> None: + self._shutdown_runtime_pool() + super().close() + + async def aclose(self) -> None: + self._shutdown_runtime_pool() + await super().aclose() + + @staticmethod + def _guard_snippet() -> str: + """generate Python code that installs safety guards""" + return ( + "import builtins, sys, os, time, atexit\n" + "try:\n" + " _PIPELINERL_TIME_LIMIT = float(os.environ.get('PIPELINERL_PY_TIMEOUT', '30'))\n" + "except ValueError:\n" + " _PIPELINERL_TIME_LIMIT = 30.0\n" + "_PIPELINERL_START = time.perf_counter()\n" + "class _ExitBlocked(RuntimeError):\n" + " pass\n" + "def _blocked_exit(*_args, **_kwargs):\n" + " raise _ExitBlocked('exit() and os._exit() are disabled in this environment.')\n" + "for _target in (builtins, sys):\n" + " for _name in ('exit', 'quit'):\n" + " if hasattr(_target, _name):\n" + " setattr(_target, _name, _blocked_exit)\n" + "if hasattr(os, '_exit'):\n" + " os._exit = _blocked_exit\n" + "def _pipelinerl_trace(frame, event, arg):\n" + " if event == 'line' and (time.perf_counter() - _PIPELINERL_START) > _PIPELINERL_TIME_LIMIT:\n" + " sys.settrace(None)\n" + " raise RuntimeError(f'Python execution timed out after {_PIPELINERL_TIME_LIMIT} seconds.')\n" + " return _pipelinerl_trace\n" + "sys.settrace(_pipelinerl_trace)\n" + "atexit.register(lambda: sys.settrace(None))\n" + ) + + async def _install_runtime_guard(self) -> None: + """Install runtime safety guard in the Python environment.""" + if self._runtime_guard_installed or not getattr(self, "client", None): + return + try: + snippet = self._guard_snippet() + if self._should_offload("run_python_code"): + future = self._ensure_runtime_pool().submit( + _call_tool_in_worker, + "run_python_code", + {"python_code": snippet}, + ) + await self._resolve_pool_future_async(future, "run_python_code") + else: + await self.client.call_tool( + "run_python_code", + {"python_code": snippet}, + ) + self._runtime_guard_installed = True + logger.debug("Runtime guard installed successfully") + except Exception: + logger.warning("Failed to install runtime guard in MCP environment", exc_info=True) + + def _postprocess_tool_observation( + self, + action: ToolCallAction, + observation: Observation, + ) -> Observation: + if not isinstance(observation, ToolResult): + return observation + call_result = observation.content + if not isinstance(call_result, CallToolResult): + return observation + if not getattr(call_result, "isError", False): + return observation + error_text = self._extract_call_result_text(call_result) + if not self._is_connection_error_message(error_text): + return observation + logger.warning( + "MCP backend failure detected for tool %s: %s", + action.function.name, + error_text, + ) + 
return self._handle_connection_failure(action, observation, error_text) + + @staticmethod + def _extract_call_result_text(call_result: CallToolResult) -> str: + if not isinstance(call_result.content, list): + return "" + parts: list[str] = [] + for block in call_result.content: + if isinstance(block, TextContent) and isinstance(block.text, str): + parts.append(block.text) + return "\n".join(parts).strip() + + @staticmethod + def _is_connection_error_message(message: str) -> bool: + lowered = message.lower() + return any(pattern in lowered for pattern in _CONNECTION_ERROR_PATTERNS) + + def _handle_connection_failure( + self, + action: ToolCallAction, + observation: ToolResult, + error_text: str, + ) -> ToolResult: + """Mark environment as broken and update observation.""" + self._broken = True + failure_message = ( + "Python tool backend became unavailable (connection lost). " + "Environment will restart after this attempt; stop issuing additional tool calls." + ) + if error_text: + failure_message = f"{failure_message}\nOriginal error: {error_text}" + + observation.content = CallToolResult( + content=[TextContent(type="text", text=failure_message)], + isError=True, + ) + observation.metadata.other.setdefault("action_execution_time", observation.metadata.other.get("action_execution_time", 0.0)) + observation.metadata.other["connection_failure"] = True + observation.metadata.other["original_error"] = error_text + self._last_failure_reason = failure_message + return observation + + def _backend_unavailable_message(self) -> str: + """Get message for unavailable backend.""" + return self._last_failure_reason or ( + "Python tool backend is restarting after a connection failure. " + "Abort this attempt and wait for a fresh environment." + ) + + @staticmethod + def _extract_answer(arguments: dict | str | None) -> str: + """Extract answer string from arguments.""" + if arguments is None: + return "" + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + return str(parsed.get("answer", "")) if isinstance(parsed, dict) else str(parsed) + except json.JSONDecodeError: + return arguments + if isinstance(arguments, dict): + return str(arguments.get("answer", "")) + return str(arguments) + + def _check_python_safety(self, arguments: dict | str | None) -> str | None: + """check for Python code problems""" + code = self._extract_python_code(arguments) + if not code: + return None + for pattern, label in self._python_blocklist: + if pattern.search(code): + return ( + f"Python execution rejected: forbidden call detected ({label}). " + "Use pure computation without exiting the runtime." 
+ ) + return None + + @staticmethod + def _extract_python_code(arguments: dict | str | None) -> str: + if arguments is None: + return "" + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + if isinstance(parsed, dict): + return str(parsed.get("python_code", parsed.get("code", ""))) + return str(parsed) + except json.JSONDecodeError: + return arguments + if isinstance(arguments, dict): + return str(arguments.get("python_code", arguments.get("code", ""))) + return str(arguments) + + def _make_error_tool_result(self, action: ToolCallAction, message: str) -> ToolResult: + result = CallToolResult( + content=[TextContent(type="text", text=message)], + isError=True, + ) + tool_result = ToolResult( + tool_call_id=getattr(action, "id", ""), + content=result, + ) + tool_result.metadata.other["action_execution_time"] = 0.0 + tool_result.metadata.other["action_kind"] = action.kind + return tool_result + + +class EmbeddedEnvironmentWorker: + def __init__(self, env_cfg: DictConfig, concurrency: int = 1): + # make repeated instantiations stable even if the caller changes its copy + self._env_cfg = OmegaConf.create(env_cfg) + self._cfg_signature = self._make_cfg_signature(self._env_cfg) + self._concurrency = max(1, concurrency) + self._init_lock = asyncio.Lock() + self._available: asyncio.Queue[_ProcessEnvironmentProxy] | None = None + self._all_envs: set[_ProcessEnvironmentProxy] = set() + + @staticmethod + def _make_cfg_signature(cfg: DictConfig) -> str: + try: + container = OmegaConf.to_container(cfg, resolve=True) + except Exception: + container = OmegaConf.to_container(cfg, resolve=False) + return json.dumps(container, sort_keys=True, default=str) + + @property + def concurrency(self) -> int: + return self._concurrency + + def matches(self, env_cfg: DictConfig) -> bool: + return self._cfg_signature == self._make_cfg_signature(env_cfg) + + def set_concurrency(self, concurrency: int) -> None: + self._concurrency = max(1, concurrency) + + async def _ensure_pool(self) -> None: + if self._available is None: + self._available = asyncio.Queue() + if len(self._all_envs) >= self._concurrency: + return + async with self._init_lock: + if len(self._all_envs) >= self._concurrency: + return + missing = self._concurrency - len(self._all_envs) + for _ in range(missing): + environment = _ProcessEnvironmentProxy(self._env_cfg) + try: + await self._init_and_reset(environment) + except Exception: + logger.exception("Failed to initialize embedded environment instance") + await self._close(environment) + raise + self._all_envs.add(environment) + await self._available.put(environment) + + @asynccontextmanager + async def alifecycle(self) -> AsyncIterator[Environment]: + """Context manager for environment lifecycle with automatic health checking.""" + await self._ensure_pool() + assert self._available is not None + + environment = await self._available.get() + try: + await self._reset(environment) + yield environment + finally: + try: + unhealthy = ( + hasattr(environment, "is_healthy") + and not environment.is_healthy() # type: ignore + ) + except Exception: + logger.warning("Failed to query embedded environment health; replacing", exc_info=True) + unhealthy = True + is_healthy = not unhealthy + + if is_healthy: + # try to reset and recycle healthy environment + try: + await self._reset(environment) + if hasattr(environment, "mark_healthy"): + environment.mark_healthy() # type: ignore + await self._available.put(environment) + except Exception: + logger.exception("Failed to recycle embedded 
environment; replacing") + await self._replace(environment) + else: + # environment is unhealthy, replace it + logger.warning("Embedded environment is unhealthy, replacing") + await self._replace(environment) + + async def _replace(self, environment: Environment) -> None: + """Replace a broken environment with a new one.""" + if environment in self._all_envs: + self._all_envs.remove(environment) + try: + await self._close(environment) + except Exception: + logger.exception("Failed to close environment during replacement") + # Refill the pool + await self._ensure_pool() + + async def _init_and_reset(self, env: Environment) -> None: + # init + if hasattr(env, "ainitialize") and inspect.iscoroutinefunction(env.ainitialize): + await env.ainitialize() # type: ignore + else: + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, env.initialize) + + # reset + await self._reset(env) + + async def _reset(self, env: Environment) -> None: + if hasattr(env, "areset") and inspect.iscoroutinefunction(env.areset): + await env.areset() # type: ignore + else: + reset_fn = getattr(env, "reset", None) + if callable(reset_fn): + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, reset_fn) + + async def _close(self, env: Environment) -> None: + loop = asyncio.get_running_loop() + + # try async close first + if hasattr(env, "aclose") and inspect.iscoroutinefunction(env.aclose): + try: + await env.aclose() # type: ignore + return + except Exception as e: + logger.debug(f"Async close failed: {e}, trying sync close") + + # fallback to sync close + try: + await loop.run_in_executor(None, env.close) + except Exception as e: + logger.debug(f"Sync close failed: {e}") diff --git a/pipelinerl/domains/mcp/rollouts.py b/pipelinerl/domains/mcp/rollouts.py new file mode 100644 index 00000000..782d4978 --- /dev/null +++ b/pipelinerl/domains/mcp/rollouts.py @@ -0,0 +1,400 @@ +import asyncio +import json +import logging +import random +import time +from collections import Counter +from typing import Dict, List +from urllib.parse import urlparse + +import aiohttp +from hydra.utils import instantiate +from omegaconf import DictConfig, OmegaConf +from tapeagents.agent import DEFAULT, Agent +from tapeagents.core import LLMCall, Tape, TrainingText +from tapeagents.dialog_tape import UserStep +from tapeagents.llms.trainable import TrainableLLM +from tapeagents.mcp import MCPEnvironment +from tapeagents.orchestrator import async_execute_agent, execute_agent, get_agent_and_env_from_config +from tapeagents.remote_environment import AsyncRemoteEnvironment + +from pipelinerl.async_llm import make_training_text +from pipelinerl.domains.mcp.env_server import EmbeddedEnvironmentWorker +from pipelinerl.domains.mcp.steps import MathAnswer +from pipelinerl.world import Job +from pipelinerl.domains.math import RewardTable, get_reward, verify_answer, verify_answer_rpc, length_penalty +from pipelinerl.rollouts import RolloutResult, BaseMetrics + +logger = logging.getLogger(__name__) + + +_embedded_worker: EmbeddedEnvironmentWorker | None = None + +class FailedRollout(Exception): + pass + +def _get_embedded_worker(env_cfg: DictConfig, concurrency: int) -> EmbeddedEnvironmentWorker: + global _embedded_worker + concurrency = max(1, concurrency) + if _embedded_worker is None or not _embedded_worker.matches(env_cfg): + _embedded_worker = EmbeddedEnvironmentWorker(env_cfg, concurrency=concurrency) + else: + _embedded_worker.set_concurrency(concurrency) + return _embedded_worker + +def 
count_tool_calls_by_category(llm_calls: List[LLMCall]) -> Dict[str, int]:
+    """
+    Count the number of tool calls for each function name category.
+
+    Args:
+        llm_calls: List of LLMCall objects
+
+    Returns:
+        Dictionary mapping function names to their counts
+    """
+    tool_call_names = []
+
+    for llm_call in llm_calls:
+        if llm_call.output.tool_calls:
+            for tool_call in llm_call.output.tool_calls:
+                tool_call_names.append(tool_call.function.name)
+
+    return dict(Counter(tool_call_names))
+
+
+class Metrics(BaseMetrics):
+    num_python_calls: int = 0
+    num_steps: int = 0
+    n_llm_calls: int = 0
+    total_execution_time: float = -1.0
+    agent_execution_time: float = -1.0
+    environment_execution_time: float = -1.0
+    overflow: bool = False
+
+async def generate_mcp_rollout(
+    cfg: DictConfig,
+    llm: TrainableLLM,
+    problem: dict,
+    session: aiohttp.ClientSession,
+) -> RolloutResult:
+    start = time.perf_counter()
+
+    chosen_url: str | None = None
+    env_host: str | None = None
+    env_port: int | None = None
+
+    if cfg.world.environment_mode == "remote":
+        env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
+        if not env_jobs:
+            raise RuntimeError("No environment servers available")
+
+        env_urls_all = [f"http://{job.hostname}:{job.port}" for job in env_jobs if job.port is not None]
+        if not env_urls_all:
+            raise RuntimeError("Environment server definitions missing ports")
+
+        while True:
+            env_urls = env_urls_all[:]
+            random.shuffle(env_urls)
+            chosen_url = None
+            for env_url in env_urls:
+                jitter = random.randint(3, 12)
+                try:
+                    environment = AsyncRemoteEnvironment(
+                        server_url=env_url, start_timeout_sec=600, start_repeat_delay=jitter)
+                    context_manager = environment.acontext(session, wait_for_env=True)
+                    env = await context_manager.__aenter__()
+                    try:
+                        await env.start_task(problem)
+                        chosen_url = env_url
+                        actions = await env.a_actions()
+                        tools_description = await env.a_tools_description()
+                        logger.debug(f"Available tools: {tools_description}")
+                        agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description)
+                        agent.llms = {DEFAULT: llm}
+
+                        tape = Tape(steps=[
+                            UserStep(content=f"{problem['task']}. You have access to the following tools: {tools_description}")
+                        ])
+                        t_exec = time.perf_counter()
+                        while True:
+                            try:
+                                tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops)
+                                tape.metadata.result.update({"total_execution_time": time.perf_counter() - t_exec})
+                                break
+                            except Exception:
+                                logger.warning("Agent execution failed; retrying in 5 seconds", exc_info=True)
+                                await asyncio.sleep(5)
+                        break  # success
+                    finally:
+                        await context_manager.__aexit__(None, None, None)
+                except Exception as e:
+                    logger.warning(f"Env start failed at {env_url}: {e}")
+                    continue
+            if chosen_url is not None:
+                break  # success
+            await asyncio.sleep(1.0)
+
+        parsed = urlparse(chosen_url)
+        env_host, env_port = parsed.hostname, parsed.port
+    else:
+        concurrency = max(1, int(getattr(cfg.world, "env_replicas_per_actor", 1)))
+        env_worker = _get_embedded_worker(cfg.environment, concurrency)
+        async with env_worker.alifecycle() as environment:
+            start_result = environment.start_task(problem)
+            tape_metadata = start_result if isinstance(start_result, dict) else {}
+
+            actions = environment.actions()
+            tools_description = environment.tools_description()
+            logger.debug(f"Embedded tools: {tools_description}")
+            agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description)
+            agent.llms = {DEFAULT: llm}
+            tape = Tape(
+                steps=[
+                    UserStep(
+                        content=f"{problem['task']}. 
You have access to the following tools: {tools_description}" + ) + ] + ) + if tape_metadata: + tape.metadata.other.update(tape_metadata) + + t_exec = time.perf_counter() + tape = await async_execute_agent(agent, tape, environment, session, max_loops=cfg.agent_max_loops) + tape.metadata.result.update({"total_execution_time": time.perf_counter() - t_exec}) + env_host = env_port = None + + reward_table = RewardTable(**dict(cfg.rewards)) + + llm_calls: list[LLMCall] = [ + LLMCall(**step.metadata.other["llm_call"]) + if isinstance(step.metadata.other["llm_call"], dict) + else step.metadata.other["llm_call"] + for step in tape.steps if step.metadata.other.get("llm_call") is not None + ] + assert len(llm_calls) > 0, "No LLM calls found" + tool_call_counts = count_tool_calls_by_category(llm_calls) + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + n_llm_calls = len(llm_calls) + if env_host and env_port: + answer_status = await verify_answer_rpc( + session=session, + host=env_host, + port=env_port, + prediction=llm_calls[-1].output.content, # type: ignore + gold=problem["answer"], + strict=True, + ) + else: + answer_status = verify_answer( + prediction=llm_calls[-1].output.content, # type: ignore + gold=problem["answer"], + strict=True, + ) + # Tape should finish with an answer + tape_finished = True if isinstance(tape.steps[-1], MathAnswer) else False + base_reward = get_reward(answer_status, tape_finished, reward_table) + + reward = base_reward + + discount_factor = float(getattr(cfg.actor, "discount_factor", 1.0)) + if discount_factor != 1.0: + total_generated_tokens = sum(getattr(call, "output_length_tokens", 0) for call in llm_calls) + reward *= discount_factor ** total_generated_tokens + + buffer_tokens = getattr(reward_table, "buffer_tokens", 0) + if buffer_tokens: + max_tokens = int(llm.parameters.get("max_tokens", 0)) + total_output_tokens = sum(getattr(text, "output_tokens", 0) for text in training_texts) + if max_tokens > 0: + reward += length_penalty(max_tokens, total_output_tokens, buffer_tokens) + + # Assign identical reward to all steps in the rollout (pipeline expects uniform rollout_reward) + for text in training_texts: + text.reward = reward + text.finished = tape_finished + + latency = time.perf_counter() - start + + agent_time = tape.metadata.result.get("agent_execution_time", -1.0) + env_time = tape.metadata.result.get("environment_execution_time", -1.0) + total_time = tape.metadata.result.get("total_execution_time", -1.0) + + + metrics = Metrics( + reward=reward, + success=answer_status == "correct", + no_error=answer_status != "unparsable", + no_answer=answer_status == "no_answer", + num_steps=len(tape.steps), + num_python_calls=tool_call_counts.get("run_python_code", 0), + n_llm_calls=n_llm_calls, + total_execution_time=total_time, + agent_execution_time=agent_time, + environment_execution_time=env_time, + overflow=not tape_finished, + ) + + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + ) + + + +def generate_mcp_rollout_with_local_env( + cfg: DictConfig | dict, + llm: TrainableLLM, + problem: dict, +) -> RolloutResult: + start = time.perf_counter() + if isinstance(cfg, dict): + cfg = OmegaConf.create(cfg) + agent, _env = get_agent_and_env_from_config(cfg) + environment: MCPEnvironment = _env + logger.info(f"Agent and environment loaded, using llm {llm.model_name} at {llm.get_base_url()}") + try: + t_exec = time.perf_counter() + start_result = 
environment.start_task(problem) + logger.info("Task started") + tape_metadata = start_result if isinstance(start_result, dict) else {} + agent.llms = {DEFAULT: llm} + tape = Tape( + steps=[ + UserStep( + content=f"{problem['task']}. You have access to the following tools: {environment.tools_description()}" + ) + ] + ) + if tape_metadata: + tape.metadata.other.update(tape_metadata) + + logger.info("Running agent..") + tape = execute_agent(agent, tape, environment, max_loops=cfg.agent_max_loops) + logger.info("Agent finished") + tape.metadata.result.update({"total_execution_time": time.perf_counter() - t_exec}) + reward_table = RewardTable(**dict(cfg.rewards)) + + llm_calls: list[LLMCall] = [ + LLMCall(**step.metadata.other["llm_call"]) + if isinstance(step.metadata.other["llm_call"], dict) + else step.metadata.other["llm_call"] + for step in tape.steps if step.metadata.other.get("llm_call") is not None + ] + assert len(llm_calls) > 0, "No LLM calls found" + tool_call_counts = count_tool_calls_by_category(llm_calls) + logger.info(f'Use {type(llm)} LLM to generate training texts') + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + n_llm_calls = len(llm_calls) + answer_status = verify_answer( + prediction=llm_calls[-1].output.content, # type: ignore + gold=problem["answer"], + strict=True, + ) + # Tape should finish with an answer + tape_finished = True if isinstance(tape.steps[-1], MathAnswer) else False + base_reward = get_reward(answer_status, tape_finished, reward_table) + + # Local reward shaping (configurable in conf/mcp.yaml) + total_shaping = 0.0 + shaping_cfg = getattr(cfg, "python_tool_shaping", None) + if shaping_cfg is not None: + num_python_calls = tool_call_counts.get("run_python_code", 0) + bonus_on_correct_with_python = float(getattr(shaping_cfg, "bonus_on_correct_with_python", 0.0)) + penalty_on_incorrect_without_python = float(getattr(shaping_cfg, "penalty_on_incorrect_without_python", 0.0)) + max_abs = float(getattr(shaping_cfg, "max_abs", 0.2)) + + # Episode-level bonuses/penalties + if answer_status == "correct" and num_python_calls >= 1: + total_shaping += bonus_on_correct_with_python + if answer_status in ("wrong", "unparsable") and num_python_calls == 0: + total_shaping -= penalty_on_incorrect_without_python + + # Clamp total shaping + if total_shaping > max_abs: + total_shaping = max_abs + if total_shaping < -max_abs: + total_shaping = -max_abs + + # Length shaping: discourage very long completions; award concise correct ones + length_cfg = getattr(cfg, "length_shaping", None) + if length_cfg is not None: + try: + # Prefer ratio-based target if provided; otherwise use absolute + if hasattr(length_cfg, "target_ratio"): + ratio = float(getattr(length_cfg, "target_ratio")) + max_gen = int(llm.parameters.get("max_tokens", 2048)) + target_tokens = int(max(1, ratio * max_gen)) + # Optional clamps + min_t = int(getattr(length_cfg, "min_target_tokens", 0)) + max_t = int(getattr(length_cfg, "max_target_tokens", 10**9)) + target_tokens = max(min_t, min(max_t, target_tokens)) + else: + target_tokens = int(getattr(length_cfg, "target_output_tokens", 512)) + slope = float(getattr(length_cfg, "slope", 0.0)) + max_penalty = float(getattr(length_cfg, "max_penalty", 0.0)) + bonus_short_correct = float(getattr(length_cfg, "bonus_on_short_correct", 0.0)) + except Exception: + target_tokens, slope, max_penalty, bonus_short_correct = 512, 0.0, 0.0, 0.0 + + # average output tokens across llm calls for this rollout + try: + avg_output_tokens = 
sum(t.output_tokens for t in training_texts) / max(1, len(training_texts))
+            except Exception:
+                avg_output_tokens = 0.0
+
+            if slope > 0.0 and max_penalty > 0.0 and avg_output_tokens > target_tokens:
+                over_by = float(avg_output_tokens - target_tokens)
+                penalty = min(max_penalty, slope * over_by)
+                total_shaping -= penalty
+
+            if bonus_short_correct > 0.0 and answer_status == "correct" and avg_output_tokens <= target_tokens:
+                total_shaping += bonus_short_correct
+
+        reward = base_reward + total_shaping
+
+        # Assign identical reward to all steps in the rollout (pipeline expects uniform rollout_reward)
+        for text in training_texts:
+            # debug_save_training_text(text)
+            text.reward = reward
+            text.finished = tape_finished
+
+        latency = time.perf_counter() - start
+
+        agent_time = tape.metadata.result.get("agent_execution_time", -1.0)
+        env_time = tape.metadata.result.get("environment_execution_time", -1.0)
+        total_time = tape.metadata.result.get("total_execution_time", -1.0)
+
+        metrics = Metrics(
+            reward=reward,
+            success=answer_status == "correct",
+            no_error=answer_status != "unparsable",
+            no_answer=answer_status == "no_answer",
+            num_steps=len(tape.steps),
+            num_python_calls=tool_call_counts.get("run_python_code", 0),
+            n_llm_calls=n_llm_calls,
+            total_execution_time=total_time,
+            agent_execution_time=agent_time,
+            environment_execution_time=env_time,
+            overflow=not tape_finished,
+        )
+
+        return RolloutResult(
+            training_texts=training_texts,
+            metrics=metrics,
+            latency=latency,
+            dataset_name=problem["dataset"]
+        )
+    except Exception as e:
+        err_msg = f"Error generating rollout: {e}"
+        logger.error(err_msg)
+        raise FailedRollout(err_msg)
+    finally:
+        try:
+            environment.close()
+        except Exception as e:
+            logger.error(f"Error closing environment: {e}")
+
+def debug_save_training_text(text: TrainingText):
+    with open("debug_training_texts.jsonl", "a") as f:
+        f.write(json.dumps({"text": text.text, "n_predicted": text.n_predicted}, ensure_ascii=False) + "\n")
\ No newline at end of file
diff --git a/pipelinerl/domains/mcp/steps.py b/pipelinerl/domains/mcp/steps.py
new file mode 100644
index 00000000..9b29a717
--- /dev/null
+++ b/pipelinerl/domains/mcp/steps.py
@@ -0,0 +1,13 @@
+from typing import Any, Literal
+from pydantic import Field
+from tapeagents.core import FinalObservation
+
+
+class MathAnswer(FinalObservation):
+    """
+    Final step indicating that the agent has finished solving a math problem.
+    The final answer must be contained within \\boxed{} format.
+    """
+
+    kind: Literal["math_answer_action"] = "math_answer_action"
+    answer: Any = Field(description="Final answer in \\boxed{} format")
diff --git a/pipelinerl/domains/miniwob/README.md b/pipelinerl/domains/miniwob/README.md
new file mode 100644
index 00000000..e9af1b42
--- /dev/null
+++ b/pipelinerl/domains/miniwob/README.md
@@ -0,0 +1,36 @@
+# Miniwob example
+
+## Prerequisites
+
+### TapeAgents
+
+Clone [TapeAgents](https://github.com/ServiceNow/TapeAgents/) in your parent folder and install it.
+```bash
+cd ..
+git clone git@github.com:ServiceNow/TapeAgents.git
+cd TapeAgents
+pip install -e .
+pip install 'tapeagents[finetune,converters]==0.1.12'
+cd ../PipelineRL
+```
+
+Make sure to add the TapeAgents folder to your Python path.
+```bash
+export PYTHONPATH="/path/to/TapeAgents:$PYTHONPATH"
+```
+
+### Miniwob
+
+See the setup instructions here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/README.md
+
+### Playwright
+
+The environment server needs Playwright installed.
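+
+If you only need Chromium (assumption: the BrowserGym miniwob tasks run in Chromium), `playwright install chromium` is a lighter option; to install every supported browser: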
+ +`playwright install` + +## Launch Command + +`python -m pipelinerl.launch --config-name miniwob environment.miniwob_url=file:///PATH/TO/miniwob-plusplus/miniwob/html/miniwob/` diff --git a/pipelinerl/miniwob/environment_server.py b/pipelinerl/domains/miniwob/environment_server.py similarity index 80% rename from pipelinerl/miniwob/environment_server.py rename to pipelinerl/domains/miniwob/environment_server.py index 13839f7a..b30f9ef7 100644 --- a/pipelinerl/miniwob/environment_server.py +++ b/pipelinerl/domains/miniwob/environment_server.py @@ -13,12 +13,14 @@ def __init__(self, exp_path: str, headless: bool = True, observation_format: str = "html", - max_session_inactivity_secs: int = 600, + env_call_timeout: int = 60, ): os.environ["MINIWOB_URL"] = miniwob_url + # Remote environment server configuration self.n_envs = n_envs self.host = host - self.max_session_inactivity_secs = max_session_inactivity_secs + self.env_call_timeout = env_call_timeout + # Individual web environment configuration self.web_env_target = web_env_target self.exp_path = exp_path self.headless = headless @@ -29,7 +31,7 @@ def launch(self, port: int): """ Serve the web environment in TapeAgent. """ - env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port, max_session_inactivity_secs=self.max_session_inactivity_secs) + env_server = EnvironmentServer(n_envs=self.n_envs, host=self.host, port=port, env_call_timeout=self.env_call_timeout) env_server.launch(OmegaConf.create({ "_target_": self.web_env_target, "exp_path": self.exp_path, diff --git a/pipelinerl/domains/miniwob/load_tasks.py b/pipelinerl/domains/miniwob/load_tasks.py new file mode 100644 index 00000000..a056a311 --- /dev/null +++ b/pipelinerl/domains/miniwob/load_tasks.py @@ -0,0 +1,216 @@ +import random +from browsergym.miniwob import ALL_MINIWOB_TASKS + +DEBUG_SPLIT = [ + "miniwob.buy-ticket", + "miniwob.bisect-angle", + "miniwob.choose-list", + "miniwob.click-checkboxes-large", + "miniwob.click-checkboxes-soft", +] +EASY_SPLIT = [ + "miniwob.click-color", + "miniwob.click-test-2", + "miniwob.click-test-transfer", + "miniwob.enter-password", + "miniwob.focus-text-2", + "miniwob.identify-shape", + "miniwob.navigate-tree", + "miniwob.phone-book", + "miniwob.read-table", + "miniwob.use-autocomplete", + "miniwob.use-autocomplete", + "miniwob.buy-ticket", + "miniwob.click-checkboxes-soft", + "miniwob.click-collapsible-2", + "miniwob.click-collapsible-2-nodelay", + "miniwob.click-collapsible-nodelay", + "miniwob.click-dialog-2", + "miniwob.click-tab-2", + "miniwob.click-tab-2-medium", + "miniwob.form-sequence-3", + "miniwob.hot-cold", + "miniwob.multi-orderings", + "miniwob.tic-tac-toe", + "miniwob.use-autocomplete-nodelay" +] +MASSIMO_TRAIN_SPLIT = [ + "miniwob.ascending-numbers", + "miniwob.bisect-angle", + "miniwob.book-flight", + "miniwob.choose-date", + "miniwob.choose-date-easy", + "miniwob.choose-date-medium", + "miniwob.choose-date-nodelay", + "miniwob.choose-list", + "miniwob.circle-center", + "miniwob.click-button-sequence", + "miniwob.click-checkboxes-soft", + "miniwob.click-checkboxes-transfer", + "miniwob.click-collapsible-2", + "miniwob.click-collapsible-2-nodelay", + "miniwob.click-collapsible-nodelay", + "miniwob.click-color", + "miniwob.click-dialog", + "miniwob.click-dialog-2", + "miniwob.click-link", + "miniwob.click-menu", + "miniwob.click-menu-2", + "miniwob.click-scroll-list", + "miniwob.click-shape", + "miniwob.click-tab", + "miniwob.click-tab-2", + "miniwob.click-tab-2-hard", + "miniwob.click-tab-2-medium", + 
"miniwob.click-test", + "miniwob.click-test-2", + "miniwob.click-test-transfer", + "miniwob.click-widget", + "miniwob.copy-paste", + "miniwob.copy-paste-2", + "miniwob.count-shape", + "miniwob.count-sides", + "miniwob.daily-calendar", + "miniwob.drag-box", + "miniwob.drag-circle", + "miniwob.drag-cube", + "miniwob.drag-items", + "miniwob.drag-items-grid", + "miniwob.drag-shapes", + "miniwob.drag-shapes-2", + "miniwob.drag-sort-numbers", + "miniwob.draw-circle", + "miniwob.draw-line", + "miniwob.email-inbox", + "miniwob.email-inbox-delete", + "miniwob.email-inbox-forward", + "miniwob.email-inbox-forward-nl", + "miniwob.email-inbox-forward-nl-turk", + "miniwob.email-inbox-important", + "miniwob.email-inbox-noscroll", + "miniwob.email-inbox-reply", + "miniwob.email-inbox-star-reply", + "miniwob.enter-date", + "miniwob.enter-text", + "miniwob.enter-text-dynamic", + "miniwob.enter-time", + "miniwob.find-greatest", + "miniwob.find-word", + "miniwob.focus-text-2", + "miniwob.form-sequence", + "miniwob.form-sequence-2", + "miniwob.generate-number", + "miniwob.grid-coordinate", + "miniwob.guess-number", + "miniwob.highlight-text", + "miniwob.hot-cold", + "miniwob.identify-shape", + "miniwob.login-user", + "miniwob.login-user-popup", + "miniwob.multi-layouts", + "miniwob.multi-orderings", + "miniwob.navigate-tree", + "miniwob.odd-or-even", + "miniwob.order-food", + "miniwob.phone-book", + "miniwob.read-table", + "miniwob.read-table-2", + "miniwob.resize-textarea", + "miniwob.right-angle", + "miniwob.scroll-text", + "miniwob.scroll-text-2", + "miniwob.search-engine", + "miniwob.sign-agreement", + "miniwob.simple-algebra", + "miniwob.social-media", + "miniwob.social-media-all", + "miniwob.social-media-some", + "miniwob.text-editor", + "miniwob.text-transform", + "miniwob.tic-tac-toe", + "miniwob.use-autocomplete", + "miniwob.use-autocomplete-nodelay", + "miniwob.use-colorwheel", + "miniwob.use-colorwheel-2", + "miniwob.use-spinner", + "miniwob.visual-addition", +] +MASSIMO_TEST_SPLIT = [ + "miniwob.buy-ticket", + "miniwob.click-button", + "miniwob.click-option", + "miniwob.click-pie-nodelay", + "miniwob.drag-single-shape", + "miniwob.email-inbox-nl-turk", + "miniwob.enter-text-2", + "miniwob.find-midpoint", + "miniwob.focus-text", + "miniwob.simple-arithmetic", + "miniwob.stock-market", + "miniwob.use-slider-2", + "miniwob.click-checkboxes", + "miniwob.click-checkboxes-large", + "miniwob.click-collapsible", + "miniwob.click-pie", + "miniwob.click-shades", + "miniwob.click-tab-2-easy", + "miniwob.enter-password", + "miniwob.form-sequence-3", + "miniwob.highlight-text-2", + "miniwob.unicode-test", + "miniwob.use-slider", +] +TRAIN_SPLIT = None +TEST_SPLIT = None + + +def load_tasks(dataset_names: list[str], train_split: float = 0.6, seeds: list[int] = [0, 1, 2, 3, 4]): + # set global variables if needed + global TRAIN_SPLIT, TEST_SPLIT + if TRAIN_SPLIT is None or TEST_SPLIT is None: + # Make a copy of tasks to avoid modifying the original + all_tasks = list(ALL_MINIWOB_TASKS) + # Use fixed seed for consistent shuffling + rng = random.Random(1406) + rng.shuffle(all_tasks) + + n_train_tasks = int(len(ALL_MINIWOB_TASKS) * train_split) + TRAIN_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[:n_train_tasks]] + TEST_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[n_train_tasks:]] + + tasks = [] + for name in dataset_names: + if name == "debug": + tasks.extend([ + # {"dataset": "miniwob.debug", "task": task, "seed": 0} for task in DEBUG_SPLIT + {"dataset": task, "task": task, "seed": 0} for task in 
DEBUG_SPLIT + ]) + elif name == "easy": + tasks.extend([ + # {"dataset": "miniwob.easy", "task": task, "seed": 0} for task in EASY_SPLIT + {"dataset": task, "task": task, "seed": 0} for task in EASY_SPLIT + ]) + elif name == "train": + tasks.extend([ + # {"dataset": "miniwob.train", "task": task, "seed": seed} + {"dataset": task, "task": task, "seed": seed} + for task in TRAIN_SPLIT for seed in seeds + ]) + elif name == "test": + tasks.extend([ + # {"dataset": "miniwob.test", "task": task, "seed": seed} + {"dataset": task, "task": task, "seed": seed} + for task in TEST_SPLIT for seed in seeds + ]) + elif name == "massimo_train": + tasks.extend([ + {"dataset": task, "task": task, "seed": seed} + for task in MASSIMO_TRAIN_SPLIT for seed in range(3,10) # seeds 0-2 are used for held out goals in Mass setup + ]) + elif name == "massimo_test": + tasks.extend([ + {"dataset": task, "task": task, "seed": seed} + for task in MASSIMO_TEST_SPLIT for seed in range(10) + ]) + return tasks + diff --git a/pipelinerl/domains/miniwob/rollouts.py b/pipelinerl/domains/miniwob/rollouts.py new file mode 100644 index 00000000..ec71ff8e --- /dev/null +++ b/pipelinerl/domains/miniwob/rollouts.py @@ -0,0 +1,341 @@ +import asyncio +import json +import logging +import os +import random +import time +import traceback + +import aiohttp +from examples.rl_webagent.steps import WebTape +from hydra.utils import instantiate +from omegaconf import DictConfig +from tapeagents.agent import DEFAULT, Agent +from tapeagents.core import LLMCall, LLMOutputParsingFailureAction, Observation +from tapeagents.io import save_json_tape +from tapeagents.llms.trainable import TrainableLLM +from tapeagents.orchestrator import async_execute_agent +from tapeagents.remote_environment import AsyncRemoteEnvironment +from tapeagents.tools.simple_browser import PageObservation + +from pipelinerl.async_llm import make_training_text +from pipelinerl.rollouts import BaseMetrics, RolloutResult +from pipelinerl.world import Job + +logger = logging.getLogger(__name__) + + +class MiniwobMetrics(BaseMetrics): + reward: float + success: bool + no_error: bool + no_answer: bool + overflow: bool + n_llm_calls: int + n_step_errors: int + n_page_observations: int + n_steps: int + total_execution_time: float + agent_execution_time: float + environment_execution_time: float + env_step_time: float + agent_step_time: float + + +def tape_contains_an_error(tape: WebTape) -> bool: + """ + Returns true if the tape ends with an error, ie if one of the following is true: + - the last step is an LLMOutputParsingFailureAction + - the tape metadata has an error + - the last step is a PageObservation with an error + """ + return ( + len(tape.steps) == 0 + or isinstance(tape.steps[-1], LLMOutputParsingFailureAction) + or tape.metadata.result.get("error") is not None + or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error) + ) + + +async def check_env_server_health(env_job: Job, session: aiohttp.ClientSession) -> dict: + """Check environment server health via HTTP API.""" + try: + url = f"http://{env_job.hostname}:{env_job.port}/health" + async with session.get(url, timeout=5) as response: + if response.status == 200: + health_data = await response.json() + return { + "healthy": True, + "health_data": health_data, + "last_check": time.time() + } + else: + error_text = await response.text() + return {"healthy": False, "error_message": f"HTTP {response.status}: {error_text}", "last_check": time.time()} + except Exception as e: + exception_type = 
type(e).__name__
+        exception_message = str(e) if str(e) else "No message available"
+        logger.exception(f"Error checking environment server health: {exception_type}: {exception_message}", stack_info=True)
+        return {"healthy": False, "error_message": f"Exception: {exception_type}: {exception_message}", "last_check": time.time(), "error_stacktrace": traceback.format_exc()}
+
+
+async def generate_miniwob_rollout(
+    cfg: DictConfig,
+    llm: TrainableLLM,
+    problem: dict,
+    session: aiohttp.ClientSession,
+) -> RolloutResult:
+    # choose a random environment server
+    # Generate environment
+    # Generate TapeAgent
+    # run the agent
+    # get llm calls from tape
+    # compute rewards
+    # get training text from llm calls
+
+    start_time = time.time()
+
+    # Overall timeout for the entire rollout to prevent hanging
+    rollout_timeout = getattr(cfg, 'rollout_timeout', 600)  # 10 minutes default
+
+    env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"]
+    env_jobs_url_tried = []
+
+    # Try each environment server with health checks until one of them returns a rollout result
+    for _ in range(len(env_jobs)):
+        # Choose the next environment server to try randomly from the ones that have not been tried yet
+        env_job = random.choice([job for job in env_jobs if f"http://{job.hostname}:{job.port}" not in env_jobs_url_tried])
+        env_job_url = f"http://{env_job.hostname}:{env_job.port}"
+        env_jobs_url_tried.append(env_job_url)
+
+        # Check server health before using
+        health = await check_env_server_health(env_job, session)
+        if not health["healthy"]:
+            logger.warning(f"Environment server {env_job_url} is unhealthy: {health}")
+            logger.warning(f"Get health error stacktrace: {health.get('error_stacktrace')}")
+            continue
+        # Log health status for monitoring
+        # (at this point the health check above has already passed)
+        logger.info(f"Using healthy environment server {env_job_url}: {health}")
+
+        try:
+            # Execute the entire rollout with a timeout
+            return await asyncio.wait_for(
+                _execute_rollout_with_timeout(cfg, llm, problem, session, start_time, env_job_url),
+                timeout=rollout_timeout
+            )
+        except asyncio.TimeoutError:
+            health = await check_env_server_health(env_job, session)
+            if stack_trace := health.get("error_stacktrace"):
+                logger.warning(f"Get health error stacktrace: {stack_trace}")
+            logger.warning(f"Rollout timeout error stacktrace: {traceback.format_exc()}")
+            logger.warning(f"Rollout timed out after {rollout_timeout} seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+            continue
+        except Exception as e:
+            health = await check_env_server_health(env_job, session)
+            if stack_trace := health.get("error_stacktrace"):
+                logger.warning(f"Get health error stacktrace: {stack_trace}")
+            logger.warning(f"Rollout failed error stacktrace: {traceback.format_exc()}")
+            logger.warning(f"Rollout failed for task {problem['dataset']}/{problem['task']}/{problem['seed']} on environment {env_job_url}. Health: {health}. Trying next server.")
+            continue
+    # If all servers failed
+    logger.error(f"All environment servers failed for task {problem['dataset']}/{problem['task']}/{problem['seed']}. 
Returning a failed rollout result.") + return _create_failed_rollout_result(problem, start_time, "all environment servers failed") + + +async def _execute_rollout_with_timeout( + cfg: DictConfig, + llm: TrainableLLM, + problem: dict, + session: aiohttp.ClientSession, + start_time: float, + env_job_url: str, +) -> RolloutResult: + # (2) Generate environment, TapeAgent, and run them to get a Tape + no_error = True # track if there was an error in the tape + environment = AsyncRemoteEnvironment(server_url=env_job_url) # type: ignore + async with environment.acontext(session, wait_for_env=True) as env: + start_attempts = cfg.start_attempts + t = time.perf_counter() + while start_attempts > 0: + try: + tape_dict, info = await env.start_task(problem) + if info.get("error"): + raise ValueError(info['error']) + break + except Exception as e: + start_attempts -= 1 + logger.warning(f"Failed to start task {problem['dataset']}/{problem['task']}/{problem['seed']}. {start_attempts} attempts remaining. Error: {e}") + if start_attempts <= 0: + logger.error(f"Failed to start task after all retry attempts: {e}") + no_error = False + tape_dict = {} + break + else: + logger.warning("Retry start task after 5 seconds.") + await asyncio.sleep(5) + logger.info( + f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds. Worker ID: {env.worker_id}. Tape dict: {tape_dict}" + ) + tape: WebTape = WebTape(**tape_dict) # convert http response dict to WebTape object + t = time.perf_counter() + if no_error: # only run the agent if the task started successfully + logger.info(f"Running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}") + agent_attempts = cfg.agent_attempts + while agent_attempts > 0: + # check if the worker is alive. + try: + # this will either raise RuntimeError if worker is not alive anymore, or return a dictionary with the worker status + worker_status = await env.check_worker_alive() + if worker_status.get("status") == "starting": + logger.warning(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is starting, waiting 5 seconds for it to be fully started.") + await asyncio.sleep(5) + continue + except Exception as e: + # if worker is dead, no need to retry + logger.exception(f"Worker {env.worker_id} for task {problem['dataset']}/{problem['task']}/{problem['seed']} and tape ID {tape.metadata.id} is dead. Error: {e}", stack_info=True) + no_error = False + break + # if worker is alive, run the agent + try: + actions = await env.a_actions() + tools_description = await env.a_tools_description() + agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description) + agent.llms = {DEFAULT: llm} + tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops) + # Check if the tape has an error from the orchestrator (e.g., SocketTimeoutError, RuntimeError: Worker is not alive, etc.) 
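+                    # Raising here routes orchestrator-reported errors through the
+                    # same retry/backoff branch as exceptions raised by the agent itself.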
+ if tape.metadata.error: + logger.error(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} returned a tape with error: {tape.metadata.error}") + raise ValueError(tape.metadata.error) + else: + # Success - break out of retry loop + logger.info(f"Agent execution for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id} finished successfully") + break + except Exception as e: + agent_attempts -= 1 + logger.warning(f"Error occurred while running agent for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}. {agent_attempts} attempts remaining. Error: {e}") + if agent_attempts <= 0: + logger.error(f"Agent execution failed after all retry attempts for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}: {e}") + no_error = False + break + else: + logger.warning(f"Retry agent execution after 5 seconds for task {problem['dataset']}/{problem['task']}/{problem['seed']} with worker ID: {env.worker_id} and tape ID {tape.metadata.id}.") + await asyncio.sleep(5) + logger.info( + f"Agent finished task {problem['dataset']}/{problem['task']}/{problem['seed']} in {time.perf_counter() - t:.2f} seconds with worker ID: {env.worker_id} and tape ID {tape.metadata.id}" + ) + tape.metadata.result.update({"total_execution_time": time.perf_counter() - t}) + + # save the tape as we go + if cfg.save_tapes: + save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id) + + # (3) Compute rewards + obs_steps = [step for step in tape if isinstance(step, Observation)] + if obs_steps: + last_obs = obs_steps[-1] + # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0 + # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L188 + # Let's take directly the RAW_REWARD_GLOBAL from the metadata + # raw_reward = last_obs.metadata.other.get("reward", 0.0) + raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) + else: + raw_reward = -1.0 + + no_error = no_error and not tape_contains_an_error(tape) + # get the number of LLMOutputParsingFailureAction in the tape + n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) + # get the number of PageObservation steps in the tape + n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)]) + + if cfg.reward_computation == "nico": + reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0 + elif cfg.reward_computation == "massimo": + reward = float(raw_reward>0) + if reward == 0.0: + reward = -1.0 + reward *= 0.98 ** n_page_observations + else: + raise ValueError(f"Invalid reward configuration: {cfg.reward_computation}") + + # (3) Get LLM calls from Tape + llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None] + n_llm_calls = len(llm_calls) + llm_calls: list[LLMCall] = [ + LLMCall(**step.metadata.other["llm_call"]) + if isinstance(step.metadata.other["llm_call"], dict) + else step.metadata.other["llm_call"] + for step in llm_calls + ] + + # (4) # For each LLM interaction in the tape, make a training example. 
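+    # all_finished stays 1 only while every completion ends with the EOS token;
+    # one truncated completion flips it to 0 and marks the rollout as overflow.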
+ all_finished = 1 + prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] + output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] + training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] + for text in training_texts: + text.reward = reward + all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 + + latency = time.time() - start_time + agent_time = tape.metadata.result.get("agent_execution_time", -1.0) + env_time = tape.metadata.result.get("environment_execution_time", -1.0) + n_observations = len([s for s in tape.steps if isinstance(s, Observation)]) # TODO: is this not the same n_page_observations?? + n_other_steps = len(tape.steps) - n_observations + metrics = MiniwobMetrics( + reward=reward, + success=reward > 0.5, + no_error=no_error, + no_answer=reward < 0, + overflow=not all_finished, + n_llm_calls=n_llm_calls, + n_step_errors=n_step_errors, + n_page_observations=n_page_observations, + n_steps=len(tape.steps), + total_execution_time=tape.metadata.result.get("total_execution_time", -1.0), + agent_execution_time=agent_time, + environment_execution_time=env_time, + env_step_time=env_time / n_observations if env_time > 0 and n_observations > 0 else -1.0, + agent_step_time=agent_time / n_other_steps if agent_time > 0 and n_other_steps > 0 else -1.0, + ) + + return RolloutResult( + training_texts=training_texts, + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + prompt_tokens=prompt_tokens, + output_tokens=output_tokens, + ) + + +def _create_failed_rollout_result(problem: dict, start_time: float, error_type: str) -> RolloutResult: + """Create a failed rollout result for timeout or other errors.""" + latency = time.time() - start_time + + # Create empty training texts and metrics for failed rollout + metrics = MiniwobMetrics( + reward=-1.0, + success=False, + no_error=False, + no_answer=True, + overflow=False, + n_llm_calls=0, + n_step_errors=0, + n_page_observations=0, + n_steps=0, + total_execution_time=latency, + agent_execution_time=-1.0, + environment_execution_time=-1.0, + env_step_time=-1.0, + agent_step_time=-1.0, + ) + + return RolloutResult( + training_texts=[], + metrics=metrics, + latency=latency, + dataset_name=problem["dataset"], + prompt_tokens=[], + output_tokens=[], + ) diff --git a/pipelinerl/finetune/rl/__init__.py b/pipelinerl/finetune/rl/__init__.py index 57aa4fa7..e74b9a0b 100644 --- a/pipelinerl/finetune/rl/__init__.py +++ b/pipelinerl/finetune/rl/__init__.py @@ -260,6 +260,7 @@ def rl_step( ) approx_kl = torch.exp(log_ratio_ref_new_clamp) - log_ratio_ref_new_clamp - 1 # Schulman KL approx + approx_kl_new_old = torch.exp(log_ratio_new_old) - log_ratio_new_old - 1 # Schulman KL approx assert torch.isfinite(approx_kl).all(), f"approx_kl is not finite: {approx_kl}" entropy_bonus_coef = linear_decay_coef(current_step, max_step, config.entropy_bonus, config.final_entropy_bonus) @@ -337,6 +338,7 @@ def rl_step( "max_advantage": advantages[masks_shifted].max().item(), "min_advantage": advantages[masks_shifted].min().item(), "kl": sum_sum(approx_kl / num_labels_in_seq, masks_shifted, segments).item(), + "kl_new_old": sum_sum(approx_kl_new_old / num_labels_in_seq, masks_shifted, segments).item(), "max_kl": approx_kl[masks_shifted].max().item(), "min_kl": approx_kl[masks_shifted].min().item(), "policy_loss": sum_sum(policy_loss / num_labels_in_seq, masks_shifted, segments).item(), @@ -381,14 +383,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: 
int, config: R """ Populates a dataset with reinforcement learning specific data columns including rewards, advantages, and token weights. - - Args: - dataset (Dataset): The input dataset to populate with RL data - eos_token_id (int): End of sequence token ID - config (RLConfig): Configuration object containing RL training parameters - - Returns: - Dataset: The dataset populated with RL-specific columns + Uses leave-one-out (LOO) reward mean: each rollout's baseline excludes its own reward. """ # Convert to pandas for processing df_init = pd.DataFrame(dataset) @@ -396,7 +391,7 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R # Step 1: calculate group-level statistics df_stats = df_init[["group_id", "rollout_index", "step_index"]].copy() - df_stats["num_tokens"] = df_init["input_ids"].apply(lambda x: len(x)) + df_stats["num_tokens"] = df_init["input_ids"].apply(len) # We assume that rewards for all tokens are the same df_stats["rollout_reward"] = df_init["rewards"].apply(lambda x: x[0]) # Check that the reward is the same for each step in the rollout @@ -406,15 +401,22 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R df_grouped = ( df_stats.groupby("group_id") .agg( - rollout_reward_mean=("rollout_reward", "mean"), + rollout_reward_sum=("rollout_reward", "sum"), + rollout_reward_count=("rollout_reward", "count"), rollout_reward_std=("rollout_reward", "std"), - group_tokens=("num_tokens", "mean"), + group_tokens=("num_tokens", "mean"), ) .reset_index() ) - assert df_grouped.columns.tolist() == ["group_id", "rollout_reward_mean", "rollout_reward_std", "group_tokens"] - - # Step 2: calculate advantages for each sample + assert df_grouped.columns.tolist() == [ + "group_id", + "rollout_reward_sum", + "rollout_reward_count", + "rollout_reward_std", + "group_tokens", + ] + + # Step 2: calculate advantages for each sample (with LOO mean) df_advantages = pd.merge( df_init[["group_id", "rollout_index", "step_index", "rewards"]], df_grouped, @@ -422,26 +424,37 @@ def populate_rl_data(dataset: list[dict[str, Any]], eos_token_id: int, config: R how="left" ) assert len(df_advantages) == len(df_init) + def calculate_advantages(row): rewards = row["rewards"] - mean = row["rollout_reward_mean"] + group_sum = row["rollout_reward_sum"] + group_count = row["rollout_reward_count"] + current_reward = rewards[0] # same reward across tokens in rollout + + # Leave-one-out mean + if group_count > 1: + loo_mean = (group_sum - current_reward) / (group_count - 1) + else: + loo_mean = current_reward # degenerate case: only one rollout in group + std = row["rollout_reward_std"] if config.divide_advantage_by_std: - advantages = [(reward - mean) / (np.nan_to_num(std) + 1e-4) for reward in rewards] + advantages = [(r - loo_mean) / (np.nan_to_num(std) + 1e-4) for r in rewards] else: - advantages = [(reward - mean) for reward in rewards] + advantages = [(r - loo_mean) for r in rewards] return advantages - df_advantages["advantages"] = df_advantages.apply( - calculate_advantages, - axis=1, + + df_advantages["advantages"] = df_advantages.apply(calculate_advantages, axis=1) + df_advantages = df_advantages.drop( + columns=["rewards", "rollout_reward_sum", "rollout_reward_count", "rollout_reward_std"] ) - df_advantages = df_advantages.drop(columns=["rewards", "rollout_reward_mean", "rollout_reward_std"]) - assert df_advantages.columns.tolist() == ["group_id", "rollout_index", "step_index", "group_tokens", "advantages"] + assert df_advantages.columns.tolist() == 
[ + "group_id", "rollout_index", "step_index", "group_tokens", "advantages" + ] # Step 3: bring advantages and group level stats back to the main df df = df_init.drop(columns=["advantages", "group_tokens"]) df = pd.merge(df, df_advantages, on=["group_id", "rollout_index", "step_index"], how="left") - # Debug print lengths of all dataframes assert len(df) == len(df_init) # Step 4: make token-level overflow and mean group length information @@ -450,7 +463,9 @@ def calculate_advantages(row): axis=1, ) df["group_tokens"] = df.apply(lambda row: [row["group_tokens"]] * len(row["input_ids"]), axis=1) - df["num_labels"] = df.apply(lambda row: [sum(1 for label in row["labels"] if label != -100)] * len(row["input_ids"]), axis=1) + df["num_labels"] = df.apply( + lambda row: [sum(1 for label in row["labels"] if label != -100)] * len(row["input_ids"]), axis=1 + ) # Step 5: move the results back to the dataset advantages_list = df["advantages"].tolist() diff --git a/pipelinerl/finetune_loop.py b/pipelinerl/finetune_loop.py index a91d1aa2..7b616c43 100644 --- a/pipelinerl/finetune_loop.py +++ b/pipelinerl/finetune_loop.py @@ -483,6 +483,7 @@ def run_finetuning_loop( finally: if actor_update_group: dist.destroy_process_group(actor_update_group) + raise RuntimeError("Finetuning loop finished, exiting worker thread") def rl_finetuning_worker( diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index b03ab8d7..e2109e34 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -1,6 +1,7 @@ import logging import math import os +import shlex import shutil import subprocess import sys @@ -18,8 +19,9 @@ logger = logging.getLogger(__name__) -# All the launch commands in this file pass the environment to child processes -os.environ["PYTHONPATH"] = f"/home/toolkit/TapeAgents" +# TODO: rm debug code +import tapeagents + os.environ["NCCL_CUMEM_ENABLE"] = "0" os.environ["TORCH_DISABLE_SHARE_RDZV_TCP_STORE"] = "1" os.environ["HF_DATASETS_DISABLE_PROGRESS_BARS"] = "1" @@ -71,6 +73,13 @@ def validate_config(cfg: DictConfig): if not hasattr(cfg.finetune.rl, "value_loss_coef") or cfg.finetune.rl.value_loss_coef <= 0.0: raise ValueError("value_loss_coef must be greater than 0 when using causal-language-modeling-with-value-head") + if cfg.finetune.seq_length < cfg.vllm_config.vllm_kwargs.max_model_len: + raise ValueError( + f"seq_length {cfg.finetune.seq_length} must be greater than or equal to " + f"vllm_kwargs.max_model_len {cfg.vllm_config.vllm_kwargs.max_model_len}" + ) + + def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus: list[int], exp_dir: Path): kwargs = cfg.vllm_config.vllm_kwargs @@ -150,6 +159,29 @@ def run_actor_llm( str(world_map.weight_update_group_size), ] + # Provide deterministic rendezvous port defaults when env vars are absent. + # vLLM spins up a torch.distributed TCPStore using VLLM_PORT. On the remote + # scheduler we observed replica crashes (store collisions, connection + # refused) because every start script inherited the same default port. By + # exporting VLLM_PORT_BASE/VLLM_PORT_STRIDE we carve out a rendezvous range + # per actor_idx while keeping the public HTTP listener at 8080+local_idx. + env = dict(os.environ) + if "VLLM_PORT_BASE" not in env: + # Each rank gets 1000 ports; 43000 leaves room below. 
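+            # e.g. rank 0 -> 43000-43999, rank 1 -> 44000-44999; VLLM_PORT_STRIDE then
+            # spaces the individual actor replicas within a rank's range.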
+ env["VLLM_PORT_BASE"] = str(43000 + 1000 * world_map.my_rank) + logger.debug( + "Setting default VLLM_PORT_BASE=%s for rank %s", + env["VLLM_PORT_BASE"], world_map.my_rank, + ) + if "VLLM_PORT_STRIDE" not in env: + env["VLLM_PORT_STRIDE"] = "20" + + env_overrides = { + key: str(env[key]) + for key in ("VLLM_PORT_BASE", "VLLM_PORT_STRIDE") + if key in env + } + # Add vLLM kwargs as separate arguments if cfg.vllm_config.vllm_kwargs: for k, v in cfg.vllm_config.vllm_kwargs.items(): @@ -162,13 +194,13 @@ def run_actor_llm( gpu_str = ",".join([str(gpu) for gpu in gpus]) logger.info(f"Running actor_llm with command: {' '.join(cmd)} on gpus: {gpu_str}") - save_command(log_dir, cmd) + save_command(log_dir, cmd, env_overrides or None) log_file_path = os.path.join(log_dir, "stdout.log") err_file_path = os.path.join(log_dir, "stderr.log") with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file: yield _popen( cmd, - env={**os.environ, "CUDA_VISIBLE_DEVICES": gpu_str}, + env={**env, "CUDA_VISIBLE_DEVICES": gpu_str}, stdout=log_file, stderr=err_file, ) @@ -365,14 +397,21 @@ def run_redis(cfg: DictConfig): yield _popen(cmd, env=dict(os.environ)) -def save_command(script_dir: Path, cmd): +def save_command(script_dir: Path, cmd, env: dict | None = None): os.makedirs(script_dir, exist_ok=True) script_path = script_dir / "start.sh" with open(script_path, "w") as f: f.write("#!/bin/bash\n") + f.write("set -e\n") + if env: + for key, value in sorted(env.items()): + quoted_value = shlex.quote(value) + f.write(f"export {key}={quoted_value}\n") # Properly quote arguments for the shell script - quoted_cmd = [f"'{arg}'" if " " in arg or "$" in arg else arg for arg in cmd] - f.write(" ".join(quoted_cmd) + "\n") + quoted_cmd = [shlex.quote(arg) for arg in cmd] + f.write("exec ") + f.write(" ".join(quoted_cmd)) + f.write("\n") os.chmod(script_path, 0o755) logger.info(f"Saved start script to {script_path}") @@ -537,6 +576,7 @@ def main(cfg: DictConfig): processes = [] + logger.info(f"TapeAgents loaded from: {tapeagents.__file__}") lead_launcher_stream = SingleStreamSpec(exp_path=exp_dir, topic="launcher_0") init_msg = {"exp_init": "true"} if world_map.my_rank == 0: @@ -576,6 +616,8 @@ def main(cfg: DictConfig): if cfg.debug.mode == "finetune": processes.extend(launch_jobs(cfg, world_map, ["finetune"])) + elif cfg.debug.mode == "llm": + processes.extend(launch_jobs(cfg, world_map, ["actor_llm"])) elif cfg.debug.mode == "actor": processes.extend(launch_jobs(cfg, world_map, ["actor", "environment", "actor_llm"])) elif cfg.debug.mode == "preprocessor": diff --git a/pipelinerl/miniwob/load_tasks.py b/pipelinerl/miniwob/load_tasks.py deleted file mode 100644 index e5056c80..00000000 --- a/pipelinerl/miniwob/load_tasks.py +++ /dev/null @@ -1,76 +0,0 @@ -import random -from browsergym.miniwob import ALL_MINIWOB_TASKS - -DEBUG_SPLIT = [ - "miniwob.buy-ticket", - "miniwob.bisect-angle", - "miniwob.choose-list", - "miniwob.click-checkboxes-large", - "miniwob.click-checkboxes-soft", -] -EASY_SPLIT = [ - "miniwob.click-color", - "miniwob.click-test-2", - "miniwob.click-test-transfer", - "miniwob.enter-password", - "miniwob.focus-text-2", - "miniwob.identify-shape", - "miniwob.navigate-tree", - "miniwob.phone-book", - "miniwob.read-table", - "miniwob.use-autocomplete", - "miniwob.use-autocomplete", - "miniwob.buy-ticket", - "miniwob.click-checkboxes-soft", - "miniwob.click-collapsible-2", - "miniwob.click-collapsible-2-nodelay", - "miniwob.click-collapsible-nodelay", - "miniwob.click-dialog-2", - 
"miniwob.click-tab-2", - "miniwob.click-tab-2-medium", - "miniwob.form-sequence-3", - "miniwob.hot-cold", - "miniwob.multi-orderings", - "miniwob.tic-tac-toe", - "miniwob.use-autocomplete-nodelay" -] -TRAIN_SPLIT = None -TEST_SPLIT = None - - -def load_tasks(dataset_names: list[str], train_split: float = 0.6, seeds: list[int] = [0, 1, 2, 3, 4]): - # set global variables if needed - global TRAIN_SPLIT, TEST_SPLIT - if TRAIN_SPLIT is None or TEST_SPLIT is None: - # Make a copy of tasks to avoid modifying the original - all_tasks = list(ALL_MINIWOB_TASKS) - # Use fixed seed for consistent shuffling - rng = random.Random(1406) - rng.shuffle(all_tasks) - - n_train_tasks = int(len(ALL_MINIWOB_TASKS) * train_split) - TRAIN_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[:n_train_tasks]] - TEST_SPLIT = [t.get_task_id() for t in ALL_MINIWOB_TASKS[n_train_tasks:]] - - tasks = [] - for name in dataset_names: - if name == "debug": - tasks.extend([ - {"dataset": "miniwob.debug", "task": task, "seed": 0} for task in DEBUG_SPLIT - ]) - elif name == "easy": - tasks.extend([ - {"dataset": "miniwob.easy", "task": task, "seed": 0} for task in EASY_SPLIT - ]) - elif name == "train": - tasks.extend([ - {"dataset": "miniwob.train", "task": task, "seed": seed} - for task in TRAIN_SPLIT for seed in seeds - ]) - elif name == "test": - tasks.extend([ - {"dataset": "miniwob.test", "task": task, "seed": seed} - for task in TEST_SPLIT for seed in seeds - ]) - return tasks - diff --git a/pipelinerl/miniwob/rollouts.py b/pipelinerl/miniwob/rollouts.py deleted file mode 100644 index bbf68860..00000000 --- a/pipelinerl/miniwob/rollouts.py +++ /dev/null @@ -1,152 +0,0 @@ - -import asyncio -import logging -import os -import random -import time -import aiohttp -from hydra.utils import instantiate -from omegaconf import DictConfig - -from pipelinerl.async_llm import llm_async_generate, make_training_text -from pipelinerl.rollouts import RolloutResult -from pipelinerl.world import Job -from tapeagents.agent import Agent, DEFAULT -from tapeagents.core import LLMOutputParsingFailureAction, Observation, LLMCall -from tapeagents.llms.trainable import TrainableLLM -from tapeagents.remote_environment import AsyncRemoteEnvironment -from tapeagents.tools.simple_browser import PageObservation -from tapeagents.orchestrator import async_execute_agent -from tapeagents.io import save_json_tape -from examples.rl_webagent.steps import WebTape - - -logger = logging.getLogger(__name__) - - -def tape_contains_an_error(tape: WebTape) -> bool: - """ - Returns true if the tape ends with an error, ie if one of the following is true: - - the last step is an LLMOutputParsingFailureAction - - the tape metadata has an error - - the last step is a PageObservation with an error - """ - return ( - isinstance(tape.steps[-1], LLMOutputParsingFailureAction) - or tape.metadata.result.get("error") is not None - or (isinstance(tape.steps[-1], PageObservation) and tape.steps[-1].error) - ) - - -async def generate_miniwob_rollout( - cfg: DictConfig, - llm: TrainableLLM, - problem: dict, - session: aiohttp.ClientSession, -) -> RolloutResult: - # choose a random environment server - # Generate environment - # Generate TapeAgent - # run the agent - # get llm calls from tape - # compute rewards - # get training text from llm calls - - start_time = time.time() - - # (1) Choose a random environment server - env_jobs = [Job(**job) for job in cfg.jobs if job["kind"] == "environment"] - # choose the env job randomly - env_job = random.choice(env_jobs) - assert 
env_job.port is not None - env_job_url = f"http://{env_job.hostname}:{env_job.port}" - - # (2) Generate environment, TapeAgent, and run them to get a Tape - environment = AsyncRemoteEnvironment(server_url=env_job_url) # type: ignore - async with environment.acontext(session, wait_for_env=True) as env: - start_attempts = cfg.start_attempts - t = time.perf_counter() - while True: - try: - tape_dict, _ = await env.start_task(problem) - break - except Exception as e: - start_attempts -= 1 - if start_attempts <= 0: - raise e - logger.warning(f"Failed to start task, retry after 5 seconds: {e}") - await asyncio.sleep(5) - logger.info(f"Task {problem['dataset']}/{problem['task']}/{problem['seed']} started in {time.perf_counter() - t:.2f} seconds") - tape: WebTape = WebTape(**tape_dict) # convert http response dict to WebTape object - t = time.perf_counter() - try: - actions = await env.a_actions() - tools_description = await env.a_tools_description() - logger.debug(f"Available tools: {tools_description}") - agent: Agent = instantiate(cfg.agent, known_actions=actions, tools_description=tools_description) - agent.llms = {DEFAULT: llm} - tape = await async_execute_agent(agent, tape, env, session, max_loops=cfg.agent_max_loops) - except Exception as e: - logger.error(f"Error occurred while running agent: {e}") - tape.metadata.result = {"execution_time": time.perf_counter() - t} - - # save the tape as we go - if cfg.save_tapes: - save_json_tape(tape, os.path.join(cfg.output_dir, "tapes"), tape.metadata.id) - - # (3) Compute rewards - last_obs = [step for step in tape if isinstance(step, Observation)][-1] - # in Miniwob, the observation "reward" is defined as RAW_REWARD_GLOBAL > 0 - # see here: https://github.com/ServiceNow/BrowserGym/blob/main/browsergym/miniwob/src/browsergym/miniwob/base.py#L183 - # Let's take directly the RAW_REWARD_GLOBAL from the metadata - # raw_reward = last_obs.metadata.other.get("reward", 0.0) - raw_reward = last_obs.metadata.other.get("info", {}).get("task_info", {}).get("REWARD_GLOBAL", -1.0) - no_error = not tape_contains_an_error(tape) - # get the number of LLMOutputParsingFailureAction in the tape - n_step_errors = len([step for step in tape.steps if isinstance(step, LLMOutputParsingFailureAction)]) - # get the number of PageObservation steps in the tape - n_page_observations = len([step for step in tape.steps if isinstance(step, PageObservation)]) - - reward = raw_reward * 0.99**n_step_errors if no_error and raw_reward >= 0 else -1.0 - - # (3) Get LLM calls from Tape - llm_calls = [step for step in tape.steps if step.metadata.other.get("llm_call") is not None] - n_llm_calls = len(llm_calls) - llm_calls: list[LLMCall] = [ - LLMCall(**step.metadata.other["llm_call"]) if isinstance(step.metadata.other["llm_call"], dict) - else step.metadata.other["llm_call"] - for step in llm_calls - ] - - # (4) # For each LLM interaction in the tape, make a training example. 
- all_finished = 0 - prompt_tokens = [llm_call.prompt_length_tokens for llm_call in llm_calls] - output_tokens = [llm_call.output_length_tokens for llm_call in llm_calls] - training_texts = [make_training_text(llm, llm_call) for llm_call in llm_calls] - for text in training_texts: - text.reward = reward - all_finished &= 1 if text.input_ids[-1] == llm.tokenizer.eos_token_id else 0 - - latency = time.time() - start_time - - metrics = { - "reward": reward, - "success": 1 if reward > 0.5 else 0, - "no_error": no_error, - "no_answer": 1 if reward < 0 else 0, - "overflow": 0 if all_finished else 1, - "n_llm_calls": n_llm_calls, - "n_step_errors": n_step_errors, - "n_page_observations": n_page_observations, - "n_steps": len(tape.steps), - } - - return RolloutResult( - training_texts=training_texts, - metrics=metrics, - latency=latency, - dataset_name=problem["dataset"], - prompt_tokens=prompt_tokens, - output_tokens=output_tokens, - ) - diff --git a/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja b/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja deleted file mode 100644 index a3bc9f02..00000000 --- a/pipelinerl/miniwob/tool_chat_template_llama3.1_json.jinja +++ /dev/null @@ -1,120 +0,0 @@ -{{- bos_token }} -{%- if custom_tools is defined %} - {%- set tools = custom_tools %} -{%- endif %} -{%- if not tools_in_user_message is defined %} - {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #} - {%- set tools_in_user_message = true %} -{%- endif %} -{%- if not date_string is defined %} - {%- if strftime_now is defined %} - {%- set date_string = strftime_now("%d %b %Y") %} - {%- else %} - {%- set date_string = "26 Jul 2024" %} - {%- endif %} -{%- endif %} -{%- if not tools is defined %} - {%- set tools = none %} -{%- endif %} - -{#- This block extracts the system message, so we can slot it into the right place. #} -{%- if messages[0]['role'] == 'system' %} - {%- if messages[0]['content'] is string %} - {%- set system_message = messages[0]['content']|trim %} - {%- else %} - {%- set system_message = messages[0]['content'][0]['text']|trim %} - {%- endif %} - {%- set messages = messages[1:] %} -{%- else %} - {%- if tools is not none %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} - {%- else %} - {%- set system_message = "" %} - {%- endif %} -{%- endif %} - -{#- System message #} -{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} -{%- if tools is not none %} - {{- "Environment: ipython\n" }} -{%- endif %} -{{- "Cutting Knowledge Date: December 2023\n" }} -{{- "Today Date: " + date_string + "\n\n" }} -{%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. 
' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} -{%- endif %} -{{- system_message }} -{{- "<|eot_id|>" }} - -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and not tools is none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} - {{- "Given the following functions, please respond with a JSON for a function call " }} - {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- first_user_message + "<|eot_id|>"}} -{%- endif %} - -{%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} - {%- if message['content'] is string %} - {{- message['content'] | trim}} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- content['text'] | trim }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- '<|eot_id|>' }} - {%- elif 'tool_calls' in message %} - {%- if not message.tool_calls|length == 1 %} - {{- raise_exception("This model only supports single tool-calls at once!") }} - {%- endif %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} - {{- '{"name": "' + tool_call.name + '", ' }} - {{- '"parameters": ' }} - {{- tool_call.arguments | tojson }} - {{- "}" }} - {{- "<|eot_id|>" }} - {%- elif message.role == "tool" or message.role == "ipython" %} - {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is string %} - {{- { "output": message.content } | tojson }} - {%- else %} - {%- for content in message['content'] %} - {%- if content['type'] == 'text' %} - {{- { "output": content['text'] } | tojson }} - {%- endif %} - {%- endfor %} - {%- endif %} - {{- "<|eot_id|>" }} - {%- endif %} -{%- endfor %} -{%- if add_generation_prompt %} - {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} -{%- endif %} \ No newline at end of file diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 65e29b4b..2164c1ef 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -160,7 +160,18 @@ def preprocess_dataset( entry["step_index"] = entry["metadata"]["step_index"] if not isinstance(tokenizer.eos_token_id, int): raise ValueError(f"Tokenizer {tokenizer} does not have an eos_token_id") - dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + try: + dataset = populate_rl_data(dataset=dataset, eos_token_id=tokenizer.eos_token_id, config=rl_config) + except Exception as e: + logger.error(f"Error in 
populate_rl_data: {e}")
+        logger.error(f"Data: {data}")
+        logger.error(f"Dataset: {dataset}")
+        logger.error(f"Tokenizer: {tokenizer}")
+        logger.error(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}")
+        logger.error(f"RL config: {rl_config}")
+        logger.error(f"LLM: {llm}")
+        logger.error(f"Seq length: {seq_length}")
+        raise e
     return dataset
 
@@ -573,6 +584,10 @@
             sample_length = len(entry["input_ids"])
             if current_length + sample_length > cfg.finetune.seq_length:
+                if len(current_batch) == 0:
+                    raise ValueError(
+                        f"sample_length is {sample_length}, but cfg.finetune.seq_length is {cfg.finetune.seq_length}"
+                    )
                 time_to_write = True
                 break  # Current micro batch is full
@@ -637,6 +652,7 @@
                 "preprocessor/queue/output": output_queue.qsize(),
                 "preprocessor/filtered_out_samples": num_filtered_out,
                 "preprocessor/total_filtered_out_samples": total_filtered_out,
+                "preprocessor/dropped_after_preprocessing": processed_entries_queue_popped_data,
             }
             if stats_aggregator.has_enough_data():
                 stats.update({"preprocessor/" + k: v for k, v in stats_aggregator.get_stats().items()})
diff --git a/pipelinerl/rl_tool_parser_plugin.py b/pipelinerl/rl_tool_parser_plugin.py
new file mode 100644
index 00000000..12e6fc2d
--- /dev/null
+++ b/pipelinerl/rl_tool_parser_plugin.py
@@ -0,0 +1,247 @@
+"""
+Tool parser plugin for the RL tool-calling format.
+"""
+
+import json
+import re
+from typing import Any  # noqa: F401
+import logging
+
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.entrypoints.openai.tool_parsers import ToolParserManager
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ExtractedToolCallInformation,
+    ToolCall,
+    FunctionCall,
+)
+
+
+@ToolParserManager.register_module("rl_tool")
+class HermesRLToolParser(ToolParser):
+    """
+    Tool parser for the RL tool-calling format, using <tool_call> markers.
+    Supports the standard Hermes format and two Apriel-style formats:
+    - <tool_calls>[{...}, {...}]</tool_calls> (preferred if present)
+    - [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] wrapper
+    """
+
+    def __init__(self, tokenizer):
+        super().__init__(tokenizer)
+
+        # Tool call markers
+        self.tool_call_start_token = "<tool_call>"
+        self.tool_call_end_token = "</tool_call>"
+
+        # Regex pattern for parsing tool calls; the second alternative tolerates a missing closing tag
+        self.tool_call_regex = re.compile(
+            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
+        )
+
+        # Apriel-specific patterns
+        self.apriel_final_response_regex = re.compile(
+            r"\[BEGIN FINAL RESPONSE\](.*?)\[END FINAL RESPONSE\]", re.DOTALL
+        )
+        # Prefer parsing aggregated tool calls from <tool_calls>...</tool_calls>.
+        # Be lenient: case-insensitive; tolerate a missing closing tag by capturing to the end.
+        self.apriel_tool_calls_regex = re.compile(
+            r"<tool_calls>\s*(.*?)\s*(?:</tool_calls>|$)", re.DOTALL | re.IGNORECASE
+        )
+
+        # State for streaming
+        self.current_tool_name_sent = False
+        self.prev_tool_call_arr = []
+        self.current_tool_id = -1
+        self.streamed_args_for_tool = []
+
+    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        """
+        Extract tool calls from the model output.
+
+        Args:
+            model_output: The raw model output string
+            request: The request object
+
+        Returns:
+            ExtractedToolCallInformation with tool calls and metadata
+        """
+        logger = logging.getLogger("pipelinerl.tool_parser")
+        # Ensure the variable exists for any fallback references below
+        final_response_match = None
+
+        try:
+            # 1) The Apriel aggregated <tool_calls> block has priority
+            tool_calls_matches = list(self.apriel_tool_calls_regex.finditer(model_output))
+            if tool_calls_matches:
+                # Use the last match (in case of multiple blocks)
+                last_match = tool_calls_matches[-1]
+                tool_calls_json = last_match.group(1).strip()
+                parsed_calls = []
+                try:
+                    parsed_calls = json.loads(tool_calls_json) if tool_calls_json else []
+                except Exception:
+                    logger.debug("Failed to parse aggregated JSON; falling back", exc_info=True)
+                    parsed_calls = []
+
+                tool_calls: list[ToolCall] = []
+                for i, pc in enumerate(parsed_calls):
+                    try:
+                        name = pc.get("name", "")
+                        args_obj = pc.get("arguments", {})
+                        if not isinstance(args_obj, (dict, list, str, int, float, bool)):
+                            args_obj = {}
+                        args_str = json.dumps(args_obj, ensure_ascii=False)
+                        call_id = pc.get("id", f"call_{i}")
+                        tool_calls.append(
+                            ToolCall(
+                                id=call_id,
+                                type="function",
+                                function=FunctionCall(name=str(name), arguments=args_str),
+                            )
+                        )
+                    except Exception:
+                        logger.debug("Skipping malformed aggregated tool call", exc_info=True)
+                        continue
+
+                # Prefer final response content if present; otherwise an empty string
+                final_response_match = self.apriel_final_response_regex.search(model_output)
+                content = final_response_match.group(1).strip() if final_response_match else ""
+
+                return ExtractedToolCallInformation(
+                    tools_called=bool(tool_calls),
+                    tool_calls=tool_calls,
+                    content=content,
+                )
+
+            # 2) Try bare JSON tool calls (no tags), but only if tools are declared in the request.
+            # Accept either a list of {name, arguments} dicts or a single such dict.
+            try:
+                tools_declared = bool(getattr(request, "tools", None))
+            except Exception:
+                tools_declared = False
+
+            if tools_declared:
+                candidate_strings: list[str] = []
+                final_response_match = self.apriel_final_response_regex.search(model_output)
+                if final_response_match:
+                    candidate_strings.append(final_response_match.group(1).strip())
+                candidate_strings.append(model_output.strip())
+
+                for candidate in candidate_strings:
+                    try:
+                        parsed = json.loads(candidate)
+                    except Exception:
+                        continue
+                    parsed_list = []
+                    if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed:
+                        parsed_list = [parsed]
+                    elif isinstance(parsed, list) and all(isinstance(it, dict) for it in parsed):
+                        parsed_list = [it for it in parsed if "name" in it and "arguments" in it]
+                    if not parsed_list:
+                        continue
+                    tool_calls: list[ToolCall] = []
+                    for i, pc in enumerate(parsed_list):
+                        try:
+                            name = pc.get("name", "")
+                            args_obj = pc.get("arguments", {})
+                            if not isinstance(args_obj, (dict, list, str, int, float, bool)):
+                                args_obj = {}
+                            args_str = json.dumps(args_obj, ensure_ascii=False)
+                            call_id = pc.get("id", f"call_{i}")
+                            tool_calls.append(
+                                ToolCall(
+                                    id=call_id,
+                                    type="function",
+                                    function=FunctionCall(name=str(name), arguments=args_str),
+                                )
+                            )
+                        except Exception:
+                            logger.debug("Skipping malformed bare-JSON tool call", exc_info=True)
+                            continue
+                    content = final_response_match.group(1).strip() if final_response_match else ""
+                    return ExtractedToolCallInformation(
+                        tools_called=bool(tool_calls),
+                        tool_calls=tool_calls,
+                        content=content,
+                    )
+
+            # 3) Fallback: look for single <tool_call> blocks (legacy / other models)
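+            # Hypothetical example of the legacy shape handled below:
+            #   'Let me verify.\n<tool_call>{"name": "run_python_code", "arguments": {"code": "print(2 + 2)"}}</tool_call>'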
content_to_search = model_output + final_response_match = self.apriel_final_response_regex.search(model_output) + if final_response_match: + final_response_content = final_response_match.group(1).strip() + if self.tool_call_start_token in final_response_content: + content_to_search = final_response_content + elif self.tool_call_start_token not in model_output: + # No tool calls found, return final response as content + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_content + ) + + # Quick check to avoid unnecessary processing + if self.tool_call_start_token not in content_to_search: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + + # Find all tool call matches + function_call_tuples = self.tool_call_regex.findall(content_to_search) + + # Parse JSON from matches + tool_calls = [] + for i, match in enumerate(function_call_tuples): + json_str = match[0] if match[0] else match[1] + try: + parsed_call = json.loads(json_str.strip()) + args_obj = parsed_call.get("arguments", {}) + if not isinstance(args_obj, (dict, list, str, int, float, bool)): + args_obj = {} + tool_call = ToolCall( + id=f"call_{i}", + type="function", + function=FunctionCall( + name=str(parsed_call.get("name", "")), + arguments=json.dumps(args_obj, ensure_ascii=False) + ) + ) + tool_calls.append(tool_call) + except Exception: + logger.debug("Skipping malformed JSON", exc_info=True) + continue + + # Determine content based on whether we found tool calls + if tool_calls and final_response_match: + # If we found tool calls in final response, use just the tool calls + content = "" + elif final_response_match: + # If we have final response but no tool calls there, use final response + content = final_response_match.group(1).strip() + else: + # Standard processing + content = model_output + + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content + ) + + except Exception: + # Never propagate exceptions to the server; log and return a safe fallback. + logger.exception("Tool parser encountered an exception; returning safe fallback.") + if final_response_match: + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=final_response_match.group(1).strip() + ) + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=model_output + ) + \ No newline at end of file diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py index 2b0a252c..2378e2a2 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -6,14 +6,13 @@ import time from pathlib import Path import traceback -from typing import Dict, Mapping, List, Any, Union +from typing import Dict, Mapping, List, Any import numpy as np from omegaconf import DictConfig import psutil import requests from importlib.metadata import distributions from transformers import PreTrainedTokenizer -from collections import defaultdict from pipelinerl.world import Job from tapeagents.llms import LLMOutput @@ -239,6 +238,9 @@ def calculate_stats(stats: List | Dict[Any, Any]) -> Dict[str, float]: if not isinstance(stats, list): raise TypeError(f"Expected stats to be a list, got {type(stats)}") + if len(stats) == 0: + return {} + aggregated_stats = { "max": float(max(stats)), "min": float(min(stats)), @@ -293,19 +295,19 @@ def wait_for_inference_servers(urls: list[str]): def wait_for_environments(cfg: DictConfig): - """ - Wait for the verifier to be ready. 
- """ + """Wait for remote environment servers to report healthy.""" + if cfg.world.environment_mode != "remote": + return + env_jobs = [Job(**job) for job in cfg.jobs if job.kind == "environment"] for job in env_jobs: while True: url = f"http://{job.hostname}:{job.port}/health" - # use requests try: response = requests.get(url) if response.status_code == 200: break - except: + except requests.exceptions.RequestException: logger.info(f"Waiting for environment at {url} to be ready...") time.sleep(5.0) @@ -321,7 +323,7 @@ def better_crashing(entrypoint_name: str): # get process if of the current process process_id = os.getpid() terminate_with_children(process_id) - logger.error(f"I should not even be here...") + logger.error("I should not even be here...") import sys sys.exit(1) diff --git a/pipelinerl/vllm0.py b/pipelinerl/vllm0.py index 8cd023bd..6858c7cd 100644 --- a/pipelinerl/vllm0.py +++ b/pipelinerl/vllm0.py @@ -3,39 +3,39 @@ import logging import os import signal -from pydantic import TypeAdapter + import torch +import torch.distributed as dist import uvloop +from pydantic import TypeAdapter from vllm import AsyncLLMEngine -from vllm.utils import FlexibleArgumentParser, set_ulimit -from vllm.entrypoints.openai.cli_args import ( - make_arg_parser, - validate_parsed_serve_args, -) +from vllm._version import version +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, - create_server_socket, build_app, + create_server_socket, init_app_state, + run_server, +) +from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, ) -from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm.logger import init_logger -from vllm._version import version -from vllm.worker.worker import Worker -from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper from vllm.executor.mp_distributed_executor import MultiprocessingDistributedExecutor +from vllm.executor.multiproc_worker_utils import ProcessWorkerWrapper +from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest from vllm.usage.usage_lib import UsageContext -from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.utils import FlexibleArgumentParser, set_ulimit from vllm.worker.multi_step_model_runner import MultiStepModelRunner +from vllm.worker.multi_step_worker import MultiStepWorker +from vllm.worker.worker import Worker - -import torch.distributed as dist -from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest import pipelinerl.torch_utils +from pipelinerl.finetune_loop import TrainerMessage, WeightUpdateRequest logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign @@ -180,6 +180,25 @@ async def run_server(args, **uvicorn_kwargs) -> None: f"invalid tool call parser: {args.tool_call_parser} (chose from {{ {','.join(valide_tool_parses)} }})" ) + # Choose a unique rendezvous port per actor to avoid torch.distributed + # TCPStore collisions across concurrently launched vLLM processes. 
+ try: + if "VLLM_PORT" not in os.environ: + actor_idx = getattr(args, "actor_llm_idx", None) + base_str = os.environ.get("VLLM_PORT_BASE", "") + stride_str = os.environ.get("VLLM_PORT_STRIDE", "10") + if actor_idx is not None and base_str.isdigit(): + base = int(base_str) + stride = int(stride_str) if stride_str.isdigit() else 10 + port = base + stride * int(actor_idx) + os.environ["VLLM_PORT"] = str(port) + logger.info( + "Using VLLM_PORT=%s (base=%s stride=%s actor_idx=%s)", + port, base, stride, actor_idx, + ) + except Exception as e: + logger.warning("Failed to set VLLM_PORT from actor_idx: %s", e) + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index be98f76f..38d1bc96 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -1,32 +1,32 @@ import logging import signal +from typing import Any, Protocol, runtime_checkable + import torch import uvloop -from vllm.utils import FlexibleArgumentParser, set_ulimit -from vllm.entrypoints.openai.cli_args import ( - make_arg_parser, - validate_parsed_serve_args, -) +from vllm._version import version +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.openai.api_server import ( - run_server, - create_server_socket, build_app, + create_server_socket, init_app_state, + run_server, +) +from vllm.entrypoints.openai.cli_args import ( + make_arg_parser, + validate_parsed_serve_args, ) -from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.tool_parsers import ToolParserManager -from vllm._version import version from vllm.usage.usage_lib import UsageContext -from vllm.config import ModelConfig +from vllm.utils import FlexibleArgumentParser, set_ulimit from vllm.v1.engine.async_llm import AsyncLLM from vllm.v1.engine.core_client import AsyncMPClient from vllm.v1.worker.gpu_model_runner import GPUModelRunner - -from pipelinerl.finetune_loop import WeightUpdateRequest -from typing import Any, Protocol, runtime_checkable import pipelinerl.torch_utils +from pipelinerl.finetune_loop import WeightUpdateRequest logger = logging.getLogger(__name__) # configure this logger individually, in order to avoid messign diff --git a/pipelinerl/world.py b/pipelinerl/world.py index f41714e4..6a06fc9f 100644 --- a/pipelinerl/world.py +++ b/pipelinerl/world.py @@ -71,7 +71,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False): if place_inference_jobs: self._place_inference_jobs(cfg) self._place_pipeline_stages(cfg) - if cfg.environment: + if cfg.environment and cfg.world.environment_mode == "remote": self._place_environments(cfg) # Place the finetune workers on the remaining gpus, take all remaining GPUs @@ -188,7 +188,10 @@ def _place_pipeline_stages(self, cfg): self.add_job(kind="preprocessor", replica_idx=worker_idx, node_rank=node, gpus=[], cpu_heavy=True) def _place_environments(self, cfg): - for worker_idx in range(cfg.world.env_replicas): + # Scale environment servers to be the same as llm servers + env_replicas_per_actor = getattr(cfg.world, "env_replicas_per_actor", 1) + total_env_replicas = cfg.world.replicas * self.llms_per_actor * env_replicas_per_actor + for worker_idx in range(total_env_replicas): node = self.get_least_busy_node() envs_at_node = len([job for job in self.job_map[node] if job.kind == "environment"]) self.add_job( diff --git 
a/pyproject.toml b/pyproject.toml index e25e3ee0..4cd5c37c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "orjson==3.10.16", "redis==5.2.1", "hydra-core>=1.3.2", + "ray[default]~=2.47.1", ] [tool.setuptools.packages.find]
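
Note on the advantage change in pipelinerl/finetune/rl/__init__.py above: the group baseline is now leave-one-out (LOO), i.e. each rollout is compared against the mean reward of the *other* rollouts in its group, so its own reward no longer biases its baseline. The sketch below is illustrative only (the helper name `loo_advantages` is not part of the codebase) and assumes one scalar reward per rollout:

def loo_advantages(rewards: list[float]) -> list[float]:
    """Leave-one-out advantages: the baseline for rollout i is the mean of the other rewards."""
    n = len(rewards)
    total = sum(rewards)
    if n == 1:
        # Degenerate group: the baseline equals the rollout's own reward, so the advantage is 0.
        return [0.0]
    return [r - (total - r) / (n - 1) for r in rewards]

# Example: a group of 4 rollouts with rewards [1, 0, 0, 1] gets baselines
# [1/3, 2/3, 2/3, 1/3] and hence advantages [2/3, -2/3, -2/3, 2/3].
print(loo_advantages([1.0, 0.0, 0.0, 1.0]))

Note that when divide_advantage_by_std is set, populate_rl_data still divides by the full-group reward std, not a leave-one-out std.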