diff --git a/tinker_cookbook/completers.py b/tinker_cookbook/completers.py
index 6ed9e214..13b72bc5 100644
--- a/tinker_cookbook/completers.py
+++ b/tinker_cookbook/completers.py
@@ -7,7 +7,7 @@
 """
 
 from dataclasses import dataclass
-from typing import TypeAlias
+from typing import Literal, TypeAlias
 
 import tinker
 
@@ -16,12 +16,14 @@
 # Interfaces
 
 StopCondition: TypeAlias = list[str] | list[int]
+StopReason: TypeAlias = Literal["length", "stop"]
 
 
 @dataclass
 class TokensWithLogprobs:
     tokens: list[int]
     maybe_logprobs: list[float] | None
+    stop_reason: StopReason = "stop"  # Default for backward compatibility
 
     @property
     def logprobs(self) -> list[float]:
@@ -29,6 +31,11 @@ def logprobs(self) -> list[float]:
             raise ValueError("Logprobs are not available")
         return self.maybe_logprobs
 
+    @property
+    def is_complete(self) -> bool:
+        """Return True if generation completed normally (hit stop sequence)."""
+        return self.stop_reason == "stop"
+
 
 class TokenCompleter:
     async def __call__(
@@ -71,12 +78,18 @@ async def __call__(
             ),
         )
 
-        # Extract tokens and logprobs from the first (and only) sample
-        sampled_tokens = sample_result.sequences[0].tokens
-        sampled_logprobs = sample_result.sequences[0].logprobs
+        # Extract tokens, logprobs, and stop_reason from the first (and only) sample
+        sampled_seq = sample_result.sequences[0]
+        sampled_tokens = sampled_seq.tokens
+        sampled_logprobs = sampled_seq.logprobs
+        stop_reason = sampled_seq.stop_reason  # "length" or "stop"
         assert sampled_logprobs is not None
-        return TokensWithLogprobs(tokens=sampled_tokens, maybe_logprobs=sampled_logprobs)
+        return TokensWithLogprobs(
+            tokens=sampled_tokens,
+            maybe_logprobs=sampled_logprobs,
+            stop_reason=stop_reason,
+        )
 
 
 class TinkerMessageCompleter(MessageCompleter):
diff --git a/tinker_cookbook/recipes/math_rl/train.py b/tinker_cookbook/recipes/math_rl/train.py
index c93fac13..30ac1619 100644
--- a/tinker_cookbook/recipes/math_rl/train.py
+++ b/tinker_cookbook/recipes/math_rl/train.py
@@ -60,6 +60,7 @@ class CLIConfig:
     max_steps_off_policy: int | None = None
     loss_fn: LossFnType = "importance_sampling"
+    filter_incomplete_trajectories: bool = False  # Filter trajectories that hit max_tokens
 
 
 def get_dataset_builder(
@@ -143,6 +144,7 @@ async def cli_main(cli_config: CLIConfig):
         if cli_config.max_steps_off_policy is not None
         else None,
         loss_fn=cli_config.loss_fn,
+        filter_incomplete_trajectories=cli_config.filter_incomplete_trajectories,
     )
 
     cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)
diff --git a/tinker_cookbook/rl/data_processing.py b/tinker_cookbook/rl/data_processing.py
index 1ce51fe8..7de93ba3 100644
--- a/tinker_cookbook/rl/data_processing.py
+++ b/tinker_cookbook/rl/data_processing.py
@@ -207,3 +207,47 @@ def remove_constant_reward_groups(
         return trajectory_groups_P[0:1]  # return singleton list in case empty
         # list will cause problems
     return new_groups
+
+
+def filter_incomplete_trajectories(
+    trajectory_groups_P: List[TrajectoryGroup],
+) -> tuple[List[TrajectoryGroup], dict[str, int]]:
+    """Filter out trajectories with incomplete rollouts (hit max_tokens instead of stop sequence)."""
+    filtered_groups: list[TrajectoryGroup] = []
+    total_trajs = 0
+    incomplete_trajs = 0
+
+    for tg in trajectory_groups_P:
+        filtered_trajs = []
+        filtered_rewards = []
+        filtered_metrics = []
+
+        for traj, reward, metrics in zip(tg.trajectories_G, tg.final_rewards_G, tg.metrics_G):
+            total_trajs += 1
+            # Check if any transition is incomplete (hit max_tokens)
+            if any(not t.is_complete for t in traj.transitions):
+                incomplete_trajs += 1
+                continue
+            filtered_trajs.append(traj)
+            filtered_rewards.append(reward)
+            filtered_metrics.append(metrics)
+
+        if filtered_trajs:
+            filtered_groups.append(
+                TrajectoryGroup(
+                    trajectories_G=filtered_trajs,
+                    final_rewards_G=filtered_rewards,
+                    metrics_G=filtered_metrics,
+                )
+            )
+
+    if incomplete_trajs > 0:
+        logger.warning(
+            f"Filtered {incomplete_trajs}/{total_trajs} incomplete trajectories "
+            f"(hit max_tokens limit)"
+        )
+
+    return filtered_groups, {
+        "filter/total_trajectories": total_trajs,
+        "filter/incomplete_trajectories": incomplete_trajs,
+    }
diff --git a/tinker_cookbook/rl/rollouts.py b/tinker_cookbook/rl/rollouts.py
index 6d5fd9bc..bf25f3e4 100644
--- a/tinker_cookbook/rl/rollouts.py
+++ b/tinker_cookbook/rl/rollouts.py
@@ -24,6 +24,7 @@ async def do_single_rollout(policy: TokenCompleter, env: Env) -> Trajectory:
             ac=ac_with_logprobs,
             reward=step_result.reward,
             episode_done=step_result.episode_done,
+            is_complete=ac_with_logprobs.is_complete,
             metrics=step_result.metrics,
         )
         transitions.append(transition)
diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index c49ae6a9..8bebf98e 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -22,6 +22,7 @@
 from tinker_cookbook.rl.data_processing import (
     assemble_training_data,
     compute_advantages,
+    filter_incomplete_trajectories,
     remove_constant_reward_groups,
 )
 from tinker_cookbook.rl.metric_util import RLTestSetEvaluator, compute_trajectory_metrics
@@ -252,6 +253,7 @@ class Config:
     enable_trace: bool = False
 
     remove_constant_reward_groups: bool = False
+    filter_incomplete_trajectories: bool = False
     eval_every: int = 20  # 0 = disabled
     save_every: int = 20  # 0 = disabled
     load_checkpoint_path: str | None = None
@@ -712,6 +714,7 @@ async def prepare_minibatch(
     model_name: str,
     kl_penalty_coef: float,
     kl_discount_factor: float,
+    do_filter_incomplete_trajectories: bool = False,
 ) -> tuple[list[tinker.Datum], dict[str, Any]]:
     """Converts the trajectories into a minibatch, and provides metrics about the minibatch"""
 
@@ -720,6 +723,14 @@ async def prepare_minibatch(
     taglist_P = [env_group_builder.logging_tags() for env_group_builder in env_group_builders_P]
     metrics.update(compute_trajectory_metrics(trajectory_groups_P, taglist_P))
 
+    # Filter incomplete trajectories (hit max_tokens instead of stop sequence)
+    if do_filter_incomplete_trajectories:
+        trajectory_groups_P, filter_stats = filter_incomplete_trajectories(trajectory_groups_P)
+        metrics.update(filter_stats)
+        if not trajectory_groups_P:
+            logger.warning("All trajectories were incomplete (hit max_tokens), skipping batch")
+            return [], metrics
+
     # Print up to two trajectory groups
     for traj_group in trajectory_groups_P[:2]:
         print_group(traj_group, tokenizer)
@@ -846,9 +857,14 @@ async def do_train_step_streaming_and_get_sampling_client(
                 model_name=cfg.model_name,
                 kl_penalty_coef=cfg.kl_penalty_coef,
                 kl_discount_factor=cfg.kl_discount_factor,
+                do_filter_incomplete_trajectories=cfg.filter_incomplete_trajectories,
             )
             metrics.update(prepare_minibatch_metrics)
 
+            # Skip if all trajectories were filtered
+            if not data_D:
+                continue
+
             # Accumulate gradients across multiple minibatches
             with timed(
                 f"train/forward_backward_substep_{i_substep}_minibatch_{i_minibatch}", metrics
@@ -914,9 +930,15 @@ async def do_train_step_and_get_sampling_client(
         model_name=cfg.model_name,
         kl_penalty_coef=cfg.kl_penalty_coef,
         kl_discount_factor=cfg.kl_discount_factor,
+        do_filter_incomplete_trajectories=cfg.filter_incomplete_trajectories,
     )
     metrics.update(prepare_minibatch_metrics)
 
+    # Handle case where all trajectories were filtered
+    if not data_D:
+        sampling_client = await training_client.save_weights_and_get_sampling_client_async()
+        return sampling_client, metrics
+
     with timed("train", metrics):
         training_logprobs_D = await train_step(
             data_D,
diff --git a/tinker_cookbook/rl/types.py b/tinker_cookbook/rl/types.py
index c80ad870..f817bcc0 100644
--- a/tinker_cookbook/rl/types.py
+++ b/tinker_cookbook/rl/types.py
@@ -32,6 +32,7 @@ class Transition:
     ac: TokensWithLogprobs
     reward: float
     episode_done: bool
+    is_complete: bool = True  # True if action hit stop sequence, False if hit max_tokens
     metrics: Metrics = field(default_factory=dict)
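
A minimal sketch (not part of the diff) of how the new stop_reason / is_complete fields on TokensWithLogprobs behave, based only on the completers.py changes above; the token and logprob values are made up for illustration:

    from tinker_cookbook.completers import TokensWithLogprobs

    # Truncated sample: the sampler ran out of max_tokens before any stop sequence.
    truncated = TokensWithLogprobs(
        tokens=[1, 2, 3],
        maybe_logprobs=[-0.1, -0.2, -0.3],
        stop_reason="length",
    )

    # Finished sample: stop_reason defaults to "stop", so existing callers keep working.
    finished = TokensWithLogprobs(
        tokens=[1, 2, 3, 4],
        maybe_logprobs=[-0.1, -0.2, -0.3, -0.05],
    )

    assert not truncated.is_complete  # hit max_tokens -> candidate for filtering
    assert finished.is_complete       # hit a stop sequence -> kept for training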