Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions tinker_cookbook/completers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from dataclasses import dataclass
from typing import TypeAlias
from typing import Literal, TypeAlias

import tinker

Expand All @@ -16,19 +16,26 @@
# Interfaces

# Stop specification accepted by completers: either a list of strings or a
# list of ints (presumably stop strings vs. stop token ids -- TODO confirm
# against the sampling API).
StopCondition: TypeAlias = list[str] | list[int]
# Why sampling ended; restricted to exactly the literals "length" and "stop".
StopReason: TypeAlias = Literal["length", "stop"]


@dataclass
class TokensWithLogprobs:
    """Sampled token ids plus their optional per-token log-probabilities.

    ``stop_reason`` records why sampling ended. It defaults to ``"stop"`` so
    that call sites constructing this class without the field keep working.
    """

    tokens: list[int]
    maybe_logprobs: list[float] | None
    stop_reason: StopReason = "stop"  # Default for backward compatibility

    @property
    def logprobs(self) -> list[float]:
        """Return ``maybe_logprobs``; raise ``ValueError`` when it is ``None``."""
        values = self.maybe_logprobs
        if values is not None:
            return values
        raise ValueError("Logprobs are not available")

    @property
    def is_complete(self) -> bool:
        """Return True when ``stop_reason`` is ``"stop"`` (a stop sequence was hit)."""
        ended_on_stop = self.stop_reason == "stop"
        return ended_on_stop


class TokenCompleter:
async def __call__(
Expand Down Expand Up @@ -71,12 +78,18 @@ async def __call__(
),
)

# Extract tokens and logprobs from the first (and only) sample
sampled_tokens = sample_result.sequences[0].tokens
sampled_logprobs = sample_result.sequences[0].logprobs
# Extract tokens, logprobs, and stop_reason from the first (and only) sample
sampled_seq = sample_result.sequences[0]
sampled_tokens = sampled_seq.tokens
sampled_logprobs = sampled_seq.logprobs
stop_reason = sampled_seq.stop_reason # "length" or "stop"
assert sampled_logprobs is not None

return TokensWithLogprobs(tokens=sampled_tokens, maybe_logprobs=sampled_logprobs)
return TokensWithLogprobs(
tokens=sampled_tokens,
maybe_logprobs=sampled_logprobs,
stop_reason=stop_reason,
)


class TinkerMessageCompleter(MessageCompleter):
Expand Down
2 changes: 2 additions & 0 deletions tinker_cookbook/recipes/math_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class CLIConfig:

max_steps_off_policy: int | None = None
loss_fn: LossFnType = "importance_sampling"
filter_incomplete_trajectories: bool = False # Filter trajectories that hit max_tokens


def get_dataset_builder(
Expand Down Expand Up @@ -143,6 +144,7 @@ async def cli_main(cli_config: CLIConfig):
if cli_config.max_steps_off_policy is not None
else None,
loss_fn=cli_config.loss_fn,
filter_incomplete_trajectories=cli_config.filter_incomplete_trajectories,
)

cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
Expand Down
44 changes: 44 additions & 0 deletions tinker_cookbook/rl/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,47 @@ def remove_constant_reward_groups(
return trajectory_groups_P[0:1] # return singleton list in case empty
# list will cause problems
return new_groups


def filter_incomplete_trajectories(
    trajectory_groups_P: List[TrajectoryGroup],
) -> tuple[List[TrajectoryGroup], dict[str, int]]:
    """Drop every trajectory that contains a transition flagged as incomplete.

    A trajectory survives only when all of its transitions have a truthy
    ``is_complete``. Its reward and metrics entries are kept or dropped
    together with it. Groups left with no surviving trajectory are omitted
    from the result entirely.

    Returns:
        A pair ``(groups, stats)``: the rebuilt groups, and a dict holding
        the number of trajectories examined and the number removed.
    """
    kept_groups: list[TrajectoryGroup] = []
    total_trajs = 0
    incomplete_trajs = 0

    for group in trajectory_groups_P:
        # Keep trajectory, reward and metrics aligned as one row each.
        rows = list(zip(group.trajectories_G, group.final_rewards_G, group.metrics_G))
        survivors = [
            row for row in rows if all(step.is_complete for step in row[0].transitions)
        ]

        total_trajs += len(rows)
        incomplete_trajs += len(rows) - len(survivors)

        if not survivors:
            # Nothing left in this group; do not emit an empty group.
            continue

        kept_groups.append(
            TrajectoryGroup(
                trajectories_G=[row[0] for row in survivors],
                final_rewards_G=[row[1] for row in survivors],
                metrics_G=[row[2] for row in survivors],
            )
        )

    if incomplete_trajs > 0:
        logger.warning(
            f"Filtered {incomplete_trajs}/{total_trajs} incomplete trajectories "
            f"(hit max_tokens limit)"
        )

    stats = {
        "filter/total_trajectories": total_trajs,
        "filter/incomplete_trajectories": incomplete_trajs,
    }
    return kept_groups, stats
1 change: 1 addition & 0 deletions tinker_cookbook/rl/rollouts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ async def do_single_rollout(policy: TokenCompleter, env: Env) -> Trajectory:
ac=ac_with_logprobs,
reward=step_result.reward,
episode_done=step_result.episode_done,
is_complete=ac_with_logprobs.is_complete,
metrics=step_result.metrics,
)
transitions.append(transition)
Expand Down
22 changes: 22 additions & 0 deletions tinker_cookbook/rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tinker_cookbook.rl.data_processing import (
assemble_training_data,
compute_advantages,
filter_incomplete_trajectories,
remove_constant_reward_groups,
)
from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics
Expand Down Expand Up @@ -252,6 +253,7 @@ class Config:
enable_trace: bool = False

remove_constant_reward_groups: bool = False
filter_incomplete_trajectories: bool = False
eval_every: int = 20 # 0 = disabled
save_every: int = 20 # 0 = disabled
load_checkpoint_path: str | None = None
Expand Down Expand Up @@ -712,6 +714,7 @@ async def prepare_minibatch(
model_name: str,
kl_penalty_coef: float,
kl_discount_factor: float,
do_filter_incomplete_trajectories: bool = False,
) -> tuple[list[tinker.Datum], dict[str, Any]]:
"""Converts the trajectories into a minibatch, and provides metrics about the minibatch"""

Expand All @@ -720,6 +723,14 @@ async def prepare_minibatch(
taglist_P = [env_group_builder.logging_tags() for env_group_builder in env_group_builders_P]
metrics.update(compute_trajectory_metrics(trajectory_groups_P, taglist_P))

# Filter incomplete trajectories (hit max_tokens instead of stop sequence)
if do_filter_incomplete_trajectories:
trajectory_groups_P, filter_stats = filter_incomplete_trajectories(trajectory_groups_P)
metrics.update(filter_stats)
if not trajectory_groups_P:
logger.warning("All trajectories were incomplete (hit max_tokens), skipping batch")
return [], metrics

# Print up to two trajectory groups
for traj_group in trajectory_groups_P[:2]:
print_group(traj_group, tokenizer)
Expand Down Expand Up @@ -846,9 +857,14 @@ async def do_train_step_streaming_and_get_sampling_client(
model_name=cfg.model_name,
kl_penalty_coef=cfg.kl_penalty_coef,
kl_discount_factor=cfg.kl_discount_factor,
do_filter_incomplete_trajectories=cfg.filter_incomplete_trajectories,
)
metrics.update(prepare_minibatch_metrics)

# Skip if all trajectories were filtered
if not data_D:
continue

# Accumulate gradients across multiple minibatches
with timed(
f"train/forward_backward_substep_{i_substep}_minibatch_{i_minibatch}", metrics
Expand Down Expand Up @@ -914,9 +930,15 @@ async def do_train_step_and_get_sampling_client(
model_name=cfg.model_name,
kl_penalty_coef=cfg.kl_penalty_coef,
kl_discount_factor=cfg.kl_discount_factor,
do_filter_incomplete_trajectories=cfg.filter_incomplete_trajectories,
)
metrics.update(prepare_minibatch_metrics)

# Handle case where all trajectories were filtered
if not data_D:
sampling_client = await training_client.save_weights_and_get_sampling_client_async()
return sampling_client, metrics

with timed("train", metrics):
training_logprobs_D = await train_step(
data_D,
Expand Down
1 change: 1 addition & 0 deletions tinker_cookbook/rl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Transition:
ac: TokensWithLogprobs
reward: float
episode_done: bool
is_complete: bool = True # True if action hit stop sequence, False if hit max_tokens
metrics: Metrics = field(default_factory=dict)


Expand Down