diff --git a/examples/multi-turn-math/README.md b/examples/multi-turn-math/README.md
index 2e9dd16c6..e8078ff34 100644
--- a/examples/multi-turn-math/README.md
+++ b/examples/multi-turn-math/README.md
@@ -1,16 +1,22 @@
 # Training a Multi-Turn GSM8K Math Agent in AReaL
 
-Files in this folder presents an example that train a multi-turn GSM8K math agent from
-Qwen/Qwen2-1.5B-Instruct, using `ArealOpenAI` APIs and its `individual` mode to organize
-training data and discount reward. Note that `sglang:disable_radix_cache` is set to true
-to stablize training.
+Files in this folder present an example that trains a multi-turn GSM8K math agent from
+Qwen/Qwen2.5-1.5B-Instruct, using the `ArealOpenAI` API and its `concat` mode to organize
+training data and discount rewards.
 
 # To run the example
 
 ```bash
-python3 -m areal.launcher.ray examples/math/multi-turn/train.py \
-  --config examples/math/multi-turn/config.yaml \
-  experiment_name=gsm8k-math-multiturn trial_name=trial0
+python3 -m areal.launcher.ray examples/multi-turn-math/gsm8k_rl_mt.py \
+  --config examples/multi-turn-math/gsm8k_grpo_mt.yaml \
+  experiment_name=gsm8k-grpo-multiturn trial_name=trial0
+```
+
+Only the following config entries are added compared to the original `gsm8k_grpo.yaml` config:
+```yaml
+export_style: concat
+agent_run_args:
+  max_turns: 2
 ```
 
 ## Reward Curve
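The `max_turns: 2` setting bounds a grade-and-retry loop: each turn the agent queries the model once, scores the reply with the GSM8K parser, and on a wrong answer appends a retry prompt before asking again. Below is a framework-free sketch of that loop for orientation; it mirrors `MultiTurnMathAgent.run_agent` from `gsm8k_rl_mt.py` further down, with `ask_model`, `grade`, and the early-exit check standing in for the `ArealOpenAI` completion call, the wrapped `gsm8k_reward_fn`, and loop logic that this patch leaves unchanged (and which is therefore not visible here).

```python
from collections.abc import Callable


def run_agent_sketch(
    question: str,
    answer: str,
    ask_model: Callable[[list[dict]], str],
    grade: Callable[[str, str], float],
    max_turns: int = 2,
) -> float:
    # Framework-free sketch of the multi-turn retry loop. `ask_model` and
    # `grade` are placeholders; the early exit on a correct answer is an
    # assumption about code that this patch does not touch.
    messages = [{"role": "user", "content": question}]
    reward = 0.0
    for _ in range(max_turns):
        reply = ask_model(messages)  # one assistant turn
        messages.append({"role": "assistant", "content": reply})
        reward = grade(reply, answer)  # 1.0 if the final answer matches, else 0.0
        if reward > 0:  # assumed early exit once the answer is judged correct
            break
        messages.append(
            {
                "role": "user",
                "content": "Your answer is incorrect. Please carefully read the "
                "original question, check the previous errors, and try to answer it again.",
            }
        )
    return reward


# Dummy model that only gets it right on the second try:
replies = iter(["The answer is 41.", "The answer is 42."])
print(
    run_agent_sketch(
        "What is 6 * 7?",
        "42",
        ask_model=lambda m: next(replies),
        grade=lambda r, a: float(a in r),
    )
)  # -> 1.0, reached on the second turn
```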
diff --git a/examples/multi-turn-math/config.yaml b/examples/multi-turn-math/gsm8k_grpo_mt.yaml
similarity index 91%
rename from examples/multi-turn-math/config.yaml
rename to examples/multi-turn-math/gsm8k_grpo_mt.yaml
index 0abc05118..45e31eb8c 100644
--- a/examples/multi-turn-math/config.yaml
+++ b/examples/multi-turn-math/gsm8k_grpo_mt.yaml
@@ -6,6 +6,10 @@ enable_offload: false
 total_train_epochs: 10
 tokenizer_path: ${actor.path}
 
+export_style: concat
+agent_run_args:
+  max_turns: 2
+
 cluster:
   n_nodes: 1
   n_gpus_per_node: 8
@@ -26,7 +30,7 @@ rollout:
   enable_rollout_tracing: false
 
 gconfig:
-  n_samples: 1
+  n_samples: 4
   min_new_tokens: 0
   max_new_tokens: 1024
   greedy: false
@@ -35,13 +39,13 @@ gconfig:
 actor:
   experiment_name: ${experiment_name}
   trial_name: ${trial_name}
-  path: Qwen/Qwen2-1.5B-Instruct
+  path: Qwen/Qwen2.5-1.5B-Instruct
   init_from_scratch: false
   disable_dropout: true
   gradient_checkpointing: true
   dtype: bfloat16
   mb_spec:
-    max_tokens_per_mb: 16384
+    max_tokens_per_mb: 10240
   optimizer:
     type: adam
     lr: 1.70e-5
@@ -52,7 +56,7 @@ actor:
     lr_scheduler_type: constant
     gradient_clipping: 1.0
     warmup_steps_proportion: 0.001
-  group_size: 1
+  group_size: ${gconfig.n_samples}
   eps_clip: 0.4
   temperature: ${gconfig.temperature}
   reward_scaling: 10.0
@@ -63,6 +67,10 @@ actor:
   use_decoupled_loss: true
   behav_imp_weight_cap: 5.0
   dynamic_sampling: false
+  reward_norm:
+    mean_level: group
+    std_level: group
+    group_size: ${gconfig.n_samples}
   adv_norm:
     mean_level: batch
     std_level: batch
@@ -88,7 +96,6 @@ sglang:
   max_running_requests: null
   context_length: 32768
   mem_fraction_static: 0.8
-  disable_radix_cache: true
 
 vllm:
   model: ${actor.path}
@@ -147,10 +154,6 @@ stats_logger:
   wandb:
     mode: disabled
 
-n_trajs: 2
-max_turns: 8
-max_tokens_per_trajectory: 8192
-
 launcher:
   inference_server_cpus_per_gpu: 4
   inference_server_mem_per_gpu: 32768
diff --git a/examples/multi-turn-math/train.py b/examples/multi-turn-math/gsm8k_rl_mt.py
similarity index 56%
rename from examples/multi-turn-math/train.py
rename to examples/multi-turn-math/gsm8k_rl_mt.py
index d28990d17..89de10c49 100644
--- a/examples/multi-turn-math/train.py
+++ b/examples/multi-turn-math/gsm8k_rl_mt.py
@@ -1,7 +1,10 @@
 import asyncio
 import os
+import sys
+from collections.abc import Callable
 from dataclasses import dataclass, field
 
+from openai.types.chat import ChatCompletion
 from transformers import PreTrainedTokenizerFast
 
 from areal.api.cli_args import GenerationHyperparameters, GRPOConfig, load_expr_config
@@ -15,28 +18,6 @@
 from areal.utils.stats_logger import StatsLogger
 
 
-@dataclass
-class AgentRLConfig(GRPOConfig):
-    n_trajs: int = field(
-        default=1,
-        metadata={
-            "help": "We could collect multiple trajectories for a single query. By default n_trajs=1."
-        },
-    )
-    max_turns: int = field(
-        default=8,
-        metadata={
-            "help": "Maximum number of turns per trajectory. By default max_turns=32."
-        },
-    )
-    max_tokens_per_trajectory: int = field(
-        default=32768,
-        metadata={
-            "help": "Maximum number of tokens per trajectory. By default max_tokens_per_trajectory=32768."
-        },
-    )
-
-
 def gsm8k_reward_fn(result, answer):
     from areal.reward.math_parser import process_results
 
@@ -46,28 +27,21 @@ def gsm8k_reward_fn(result, answer):
 class MultiTurnMathAgent:
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerFast,
-        max_tokens_per_turn: int = 1024,
-        max_turns: int = 8,
-        max_total_tokens: int = 32768,
+        gconfig: GenerationHyperparameters,
+        reward_fn: Callable[[str, str], float | int],
+        max_turns: int = 2,
     ):
-        self.tokenizer = tokenizer
-        self.max_tokens_per_turn = max_tokens_per_turn
+        self.gconfig = gconfig
         self.max_turns = max_turns
-        self.max_total_tokens = max_total_tokens
-        self.async_reward_fn = AsyncRewardWrapper(gsm8k_reward_fn)
+        self.async_reward_fn = AsyncRewardWrapper(reward_fn)
 
     async def run_agent(self, data, client: ArealOpenAI):
         messages = data["messages"].copy()
-        num_turns_left = self.max_turns
-        completions = []
-        while num_turns_left > 0:
-            response = await client.chat.completions.create(
+        for _ in range(self.max_turns):
+            response: ChatCompletion = await client.chat.completions.create(
                 messages=messages,
-                temperature=1.0,
-                max_completion_tokens=self.max_tokens_per_turn,
+                **self.gconfig.to_openai_args_dict(),
             )
-            completions.append(response)
             message = response.choices[0].message
             messages.append(message)
             reward = await self.async_reward_fn(
@@ -84,43 +58,46 @@ async def run_agent(self, data, client: ArealOpenAI):
                     "Please carefully read the original question, check the previous errors, and try to answer it again.",
                 }
             )
-            num_turns_left -= 1
         return reward
 
 
 class MultiturnRLVRWorkflow(RolloutWorkflow):
     def __init__(
         self,
+        reward_fn: Callable[[str, str], float | int],
         gconfig: GenerationHyperparameters,
         tokenizer: PreTrainedTokenizerFast,
         dump_dir: str | None = None,
         rollout_stat_scope: str = "rollout",
-        n_trajs: int = 1,
-        max_tokens: int = 32768,
-        max_turns: int = 8,
+        export_style: str = "concat",
+        max_turns: int = 2,
     ):
-        # NOTE(refactor): stop tokens are not used in this workflow, adding stop and pad token ids may not be necessary
-        self.gconfig = gconfig.new_with_stop_and_pad_token_ids(tokenizer)
-        self.gconfig.n_samples = 1
+        self.n_trajs = gconfig.n_samples
         self.tokenizer = tokenizer
         self.dump_dir = dump_dir
-        self.max_tokens = max_tokens
         self.rollout_stat_scope = rollout_stat_scope
+        self.export_style = export_style
+        if export_style not in ["individual", "concat"]:
+            raise ValueError(f"Invalid export style: {export_style}")
+        self.chat_template_type = "concat" if export_style == "concat" else "hf"
+
         if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)
 
         # Search hyper-parameters
-        self.n_trajs = n_trajs
         self.agent = MultiTurnMathAgent(
-            tokenizer=self.tokenizer,
-            max_tokens_per_turn=self.gconfig.max_new_tokens,
+            gconfig=gconfig.new(n_samples=1),
+            reward_fn=reward_fn,
             max_turns=max_turns,
-            max_total_tokens=max_tokens,
         )
 
     async def arun_episode(self, engine, data):
         clients = [
-            ArealOpenAI(engine=engine, tokenizer=self.tokenizer)
+            ArealOpenAI(
+                engine=engine,
+                tokenizer=self.tokenizer,
+                chat_template_type=self.chat_template_type,
+            )
             for _ in range(self.n_trajs)
         ]
 
@@ -140,40 +117,68 @@ async def arun_episode(self, engine, data):
         completions_with_reward = {}
         for client in clients:
             client.apply_reward_discount(turn_discount=0.9)
-            completions = client.export_interactions(style="individual")
+            completions = client.export_interactions(style=self.export_style)
             completions_with_reward.update(completions)
         return completions_with_reward
 
 
-def main(args):
-    config, _ = load_expr_config(args, AgentRLConfig)
+@dataclass
+class MultiTurnGRPOConfig(GRPOConfig):
+    agent_run_args: dict = field(
+        default_factory=dict,
+        metadata={"help": "Arguments for running the agent."},
+    )
+    export_style: str = field(
+        default="concat",
+        metadata={
+            "help": "Export style for the completions. By default export_style=concat."
+        },
+    )
+
+def main(args):
+    config, _ = load_expr_config(args, MultiTurnGRPOConfig)
 
     tokenizer = load_hf_tokenizer(config.tokenizer_path)
 
-    # Load dataset
     train_dataset = get_custom_dataset(
-        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
+        split="train",
+        dataset_config=config.train_dataset,
+        tokenizer=tokenizer,
     )
 
-    # Create trainer (no valid_dataset for this example)
-    with PPOTrainer(config, train_dataset, valid_dataset=None) as trainer:
-        # Create rollout workflow
+    valid_dataset = get_custom_dataset(
+        split="test",
+        dataset_config=config.valid_dataset,
+        tokenizer=tokenizer,
+    )
+
+    with PPOTrainer(
+        config,
+        train_dataset=train_dataset,
+        valid_dataset=valid_dataset,
+    ) as trainer:
+        max_turns = config.agent_run_args.get("max_turns", 2)
+        log_path = StatsLogger.get_log_path(config.stats_logger)
+
         workflow = MultiturnRLVRWorkflow(
+            reward_fn=gsm8k_reward_fn,
             gconfig=config.gconfig,
             tokenizer=trainer.tokenizer,
-            n_trajs=config.n_trajs,
-            max_tokens=config.max_tokens_per_trajectory,
-            max_turns=config.max_turns,
-            dump_dir=os.path.join(
-                StatsLogger.get_log_path(config.stats_logger), "generated"
-            ),
+            dump_dir=os.path.join(log_path, "generated"),
+            export_style=config.export_style,
+            max_turns=max_turns,
         )
-
-        # Run training
-        trainer.train(workflow, eval_workflow=None)
+        eval_workflow = MultiturnRLVRWorkflow(
+            reward_fn=gsm8k_reward_fn,
+            gconfig=config.gconfig.new(temperature=0.6, n_samples=1),
+            tokenizer=trainer.tokenizer,
+            rollout_stat_scope="eval-rollout",
+            dump_dir=os.path.join(log_path, "generated-eval"),
+            export_style=config.export_style,
+            max_turns=max_turns,
        )
+        trainer.train(workflow, eval_workflow)
 
 
 if __name__ == "__main__":
-    import sys
-
     main(sys.argv[1:])
diff --git a/examples/multi-turn-math/reward_curve.png b/examples/multi-turn-math/reward_curve.png
index 93ce22d44..241ddfbc9 100644
Binary files a/examples/multi-turn-math/reward_curve.png and b/examples/multi-turn-math/reward_curve.png differ
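Before exporting, each trajectory's client calls `apply_reward_discount(turn_discount=0.9)`. As a rough mental model only (the exact semantics live inside AReaL's `ArealOpenAI`, not in this patch), one common convention is to propagate the final outcome backwards with a 0.9 decay per turn, which the standalone sketch below illustrates.

```python
def discount_rewards(
    final_reward: float, num_turns: int, turn_discount: float = 0.9
) -> list[float]:
    # Standalone illustration of one ASSUMED discounting convention; AReaL's
    # apply_reward_discount may distribute credit differently.
    # Turn i (0-indexed) sits (num_turns - 1 - i) steps before the final turn.
    return [final_reward * turn_discount ** (num_turns - 1 - i) for i in range(num_turns)]


print(discount_rewards(1.0, num_turns=2))  # [0.9, 1.0] with max_turns: 2
```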