examples/multi-turn-math/README.md (12 additions, 6 deletions)
@@ -1,16 +1,22 @@
# Training a Multi-Turn GSM8K Math Agent in AReaL

Files in this folder present an example that trains a multi-turn GSM8K math agent from
Qwen/Qwen2-1.5B-Instruct, using `ArealOpenAI` APIs and its `individual` mode to organize
training data and discount rewards. Note that `sglang:disable_radix_cache` is set to true
to stabilize training.
Qwen/Qwen2.5-1.5B-Instruct, using `ArealOpenAI` APIs and its `concat` mode to organize
training data and discount rewards.
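
The two export styles organize the same multi-turn interactions into training data
differently. The sketch below describes the assumed semantics; the method names match
the code in this example, but the exact sample shapes are illustrative:

```python
# Assumed semantics of the two export styles (illustrative only):
#
#   samples = client.export_interactions(style="individual")
#   -> one training sample per assistant turn, each carrying its own
#      (turn-discounted) reward
#
#   samples = client.export_interactions(style="concat")
#   -> the whole dialogue flattened into a single token sequence, i.e.
#      one training sample per trajectory
```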

## To run the example

```bash
python3 -m areal.launcher.ray examples/math/multi-turn/train.py \
--config examples/math/multi-turn/config.yaml \
experiment_name=gsm8k-math-multiturn trial_name=trial0
python3 -m areal.launcher.ray examples/multi-turn-math/gsm8k_rl_mt.py \
--config examples/multi-turn-math/gsm8k_grpo_mt.yaml \
experiment_name=gsm8k-grpo-multiturn trial_name=trial0
```

Only the following config entries are added compared to the original `gsm8k_grpo.yaml` config:
```yaml
export_style: concat
agent_run_args:
  max_turns: 2
```
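
These keys are consumed in `gsm8k_rl_mt.py` (shown further down in this diff):
`agent_run_args` is a plain dict, so `max_turns` is read with a default. An excerpt
from `main()` (not standalone code):

```python
max_turns = config.agent_run_args.get("max_turns", 2)
workflow = MultiturnRLVRWorkflow(
    reward_fn=gsm8k_reward_fn,
    gconfig=config.gconfig,
    tokenizer=trainer.tokenizer,
    dump_dir=os.path.join(log_path, "generated"),
    export_style=config.export_style,
    max_turns=max_turns,
)
```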
## Reward Curve
examples/multi-turn-math/gsm8k_grpo_mt.yaml
@@ -6,6 +6,10 @@ enable_offload: false
total_train_epochs: 10
tokenizer_path: ${actor.path}

export_style: concat
agent_run_args:
  max_turns: 2

cluster:
  n_nodes: 1
  n_gpus_per_node: 8
@@ -26,7 +30,7 @@ rollout:
  enable_rollout_tracing: false

gconfig:
  n_samples: 1
  n_samples: 4
  min_new_tokens: 0
  max_new_tokens: 1024
  greedy: false
@@ -35,13 +39,13 @@ gconfig:
actor:
  experiment_name: ${experiment_name}
  trial_name: ${trial_name}
  path: Qwen/Qwen2-1.5B-Instruct
  path: Qwen/Qwen2.5-1.5B-Instruct
  init_from_scratch: false
  disable_dropout: true
  gradient_checkpointing: true
  dtype: bfloat16
  mb_spec:
    max_tokens_per_mb: 16384
    max_tokens_per_mb: 10240
  optimizer:
    type: adam
    lr: 1.70e-5
@@ -52,7 +56,7 @@ actor:
    lr_scheduler_type: constant
    gradient_clipping: 1.0
    warmup_steps_proportion: 0.001
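  # group_size groups the n_samples responses drawn for each prompt when
  # computing GRPO statistics (hence it mirrors gconfig.n_samples).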
  group_size: 1
  group_size: ${gconfig.n_samples}
  eps_clip: 0.4
  temperature: ${gconfig.temperature}
  reward_scaling: 10.0
@@ -63,6 +67,10 @@ actor:
  use_decoupled_loss: true
  behav_imp_weight_cap: 5.0
  dynamic_sampling: false
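  # reward_norm normalizes rewards with mean/std computed within each group of
  # n_samples responses to the same prompt (group-level, GRPO-style).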
  reward_norm:
    mean_level: group
    std_level: group
    group_size: ${gconfig.n_samples}
  adv_norm:
    mean_level: batch
    std_level: batch
@@ -88,7 +96,6 @@ sglang:
  max_running_requests: null
  context_length: 32768
  mem_fraction_static: 0.8
  disable_radix_cache: true

vllm:
  model: ${actor.path}
@@ -147,10 +154,6 @@ stats_logger:
  wandb:
    mode: disabled

n_trajs: 2
max_turns: 8
max_tokens_per_trajectory: 8192

launcher:
  inference_server_cpus_per_gpu: 4
  inference_server_mem_per_gpu: 32768
examples/multi-turn-math/gsm8k_rl_mt.py
@@ -1,7 +1,10 @@
import asyncio
import os
import sys
from collections.abc import Callable
from dataclasses import dataclass, field

from openai.types.chat import ChatCompletion
from transformers import PreTrainedTokenizerFast

from areal.api.cli_args import GenerationHyperparameters, GRPOConfig, load_expr_config
@@ -15,28 +18,6 @@
from areal.utils.stats_logger import StatsLogger


@dataclass
class AgentRLConfig(GRPOConfig):
    n_trajs: int = field(
        default=1,
        metadata={
            "help": "We could collect multiple trajectories for a single query. By default n_trajs=1."
        },
    )
    max_turns: int = field(
        default=8,
        metadata={
            "help": "Maximum number of turns per trajectory. By default max_turns=32."
        },
    )
    max_tokens_per_trajectory: int = field(
        default=32768,
        metadata={
            "help": "Maximum number of tokens per trajectory. By default max_tokens_per_trajectory=32768."
        },
    )


def gsm8k_reward_fn(result, answer):
    from areal.reward.math_parser import process_results

@@ -46,28 +27,21 @@ def gsm8k_reward_fn(result, answer):
class MultiTurnMathAgent:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizerFast,
        max_tokens_per_turn: int = 1024,
        max_turns: int = 8,
        max_total_tokens: int = 32768,
        gconfig: GenerationHyperparameters,
        reward_fn: Callable[[str, str], float | int],
        max_turns: int = 2,
    ):
        self.tokenizer = tokenizer
        self.max_tokens_per_turn = max_tokens_per_turn
        self.gconfig = gconfig
        self.max_turns = max_turns
        self.max_total_tokens = max_total_tokens
        self.async_reward_fn = AsyncRewardWrapper(gsm8k_reward_fn)
        self.async_reward_fn = AsyncRewardWrapper(reward_fn)

    async def run_agent(self, data, client: ArealOpenAI):
        messages = data["messages"].copy()
        num_turns_left = self.max_turns
        completions = []
        while num_turns_left > 0:
            response = await client.chat.completions.create(
        for _ in range(self.max_turns):
            response: ChatCompletion = await client.chat.completions.create(
                messages=messages,
                temperature=1.0,
                max_completion_tokens=self.max_tokens_per_turn,
                **self.gconfig.to_openai_args_dict(),
            )
            completions.append(response)
            message = response.choices[0].message
            messages.append(message)
            reward = await self.async_reward_fn(
@@ -84,43 +58,46 @@ async def run_agent(self, data, client: ArealOpenAI):
"Please carefully read the original question, check the previous errors, and try to answer it again.",
}
)
num_turns_left -= 1
return reward


class MultiturnRLVRWorkflow(RolloutWorkflow):
    def __init__(
        self,
        reward_fn: Callable[[str, str], float | int],
        gconfig: GenerationHyperparameters,
        tokenizer: PreTrainedTokenizerFast,
        dump_dir: str | None = None,
        rollout_stat_scope: str = "rollout",
        n_trajs: int = 1,
        max_tokens: int = 32768,
        max_turns: int = 8,
        export_style: str = "concat",
        max_turns: int = 2,
    ):
        # NOTE(refactor): stop tokens are not used in this workflow, adding stop and pad token ids may not be necessary
        self.gconfig = gconfig.new_with_stop_and_pad_token_ids(tokenizer)
        self.gconfig.n_samples = 1
        self.n_trajs = gconfig.n_samples
        self.tokenizer = tokenizer
        self.dump_dir = dump_dir
        self.max_tokens = max_tokens
        self.rollout_stat_scope = rollout_stat_scope
        self.export_style = export_style
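        # Export style controls how interactions become training data; "concat"
        # also switches the chat template so the whole dialogue is tokenized as
        # one sequence, while other styles use the per-turn HF template
        # (assumed semantics).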
        if export_style not in ["individual", "concat"]:
            raise ValueError(f"Invalid export style: {export_style}")
        self.chat_template_type = "concat" if export_style == "concat" else "hf"

        if self.dump_dir is not None and not os.path.exists(self.dump_dir):
            os.makedirs(self.dump_dir, exist_ok=True)

        # Search hyper-parameters
        self.n_trajs = n_trajs
        self.agent = MultiTurnMathAgent(
            tokenizer=self.tokenizer,
            max_tokens_per_turn=self.gconfig.max_new_tokens,
            gconfig=gconfig.new(n_samples=1),
            reward_fn=reward_fn,
            max_turns=max_turns,
            max_total_tokens=max_tokens,
        )

    async def arun_episode(self, engine, data):
        clients = [
            ArealOpenAI(engine=engine, tokenizer=self.tokenizer)
            ArealOpenAI(
                engine=engine,
                tokenizer=self.tokenizer,
                chat_template_type=self.chat_template_type,
            )
            for _ in range(self.n_trajs)
        ]

@@ -140,40 +117,68 @@ async def arun_episode(self, engine, data):
        completions_with_reward = {}
        for client in clients:
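            # Discount rewards across turns (turn_discount=0.9), so completions
            # from earlier turns presumably receive geometrically less credit
            # for the final outcome.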
            client.apply_reward_discount(turn_discount=0.9)
            completions = client.export_interactions(style="individual")
            completions = client.export_interactions(style=self.export_style)
            completions_with_reward.update(completions)
        return completions_with_reward


def main(args):
    config, _ = load_expr_config(args, AgentRLConfig)
@dataclass
class MultiTurnGRPOConfig(GRPOConfig):
    agent_run_args: dict = field(
        default_factory=dict,
        metadata={"help": "Arguments for running the agent."},
    )
    export_style: str = field(
        default="concat",
        metadata={
            "help": "Export style for the completions. By default export_style=concat."
        },
    )


def main(args):
    config, _ = load_expr_config(args, MultiTurnGRPOConfig)
    tokenizer = load_hf_tokenizer(config.tokenizer_path)

    # Load dataset
    train_dataset = get_custom_dataset(
        split="train", dataset_config=config.train_dataset, tokenizer=tokenizer
        split="train",
        dataset_config=config.train_dataset,
        tokenizer=tokenizer,
    )

    # Create trainer (no valid_dataset for this example)
    with PPOTrainer(config, train_dataset, valid_dataset=None) as trainer:
        # Create rollout workflow
    valid_dataset = get_custom_dataset(
        split="test",
        dataset_config=config.valid_dataset,
        tokenizer=tokenizer,
    )

    with PPOTrainer(
        config,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
    ) as trainer:
        max_turns = config.agent_run_args.get("max_turns", 2)
        log_path = StatsLogger.get_log_path(config.stats_logger)

        workflow = MultiturnRLVRWorkflow(
            reward_fn=gsm8k_reward_fn,
            gconfig=config.gconfig,
            tokenizer=trainer.tokenizer,
            n_trajs=config.n_trajs,
            max_tokens=config.max_tokens_per_trajectory,
            max_turns=config.max_turns,
            dump_dir=os.path.join(
                StatsLogger.get_log_path(config.stats_logger), "generated"
            ),
            dump_dir=os.path.join(log_path, "generated"),
            export_style=config.export_style,
            max_turns=max_turns,
        )

        # Run training
        trainer.train(workflow, eval_workflow=None)
        eval_workflow = MultiturnRLVRWorkflow(
            reward_fn=gsm8k_reward_fn,
            gconfig=config.gconfig.new(temperature=0.6, n_samples=1),
            tokenizer=trainer.tokenizer,
            rollout_stat_scope="eval-rollout",
            dump_dir=os.path.join(log_path, "generated-eval"),
            export_style=config.export_style,
            max_turns=max_turns,
        )
        trainer.train(workflow, eval_workflow)


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])
Binary file modified examples/multi-turn-math/reward_curve.png