From 150a412b46ed3f642e815d11d686ef6510864c71 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Wed, 10 Sep 2025 10:18:31 +0530 Subject: [PATCH 1/5] feat: add weave callback for tracing completions. --- trl/import_utils.py | 5 + trl/trainer/callbacks.py | 263 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 1 deletion(-) diff --git a/trl/import_utils.py b/trl/import_utils.py index dd594e36ef6..56a14653484 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -38,6 +38,7 @@ _vllm_available = _is_package_available("vllm") _vllm_ascend_available = _is_package_available("vllm_ascend") _joblib_available = _is_package_available("joblib") +_weave_available = _is_package_available("weave") def is_deepspeed_available() -> bool: @@ -92,6 +93,10 @@ def is_joblib_available() -> bool: return _joblib_available +def is_weave_available() -> bool: + return _weave_available + + class _LazyModule(ModuleType): """ Module class that surfaces all objects but only performs associated imports when the objects are requested. diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index c6500102997..17fafef7b43 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -35,7 +35,7 @@ from transformers.utils import is_rich_available from ..data_utils import maybe_apply_chat_template -from ..import_utils import is_mergekit_available +from ..import_utils import is_mergekit_available, is_weave_available from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge @@ -51,6 +51,9 @@ if is_wandb_available(): import wandb +if is_weave_available(): + import weave + # Logger for module-level logging logger = logging.getLogger(__name__) @@ -514,6 +517,264 @@ def on_step_end(self, args, state, control, **kwargs): self._last_logged_step = state.global_step +class WeaveTraceCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that logs completions to W&B Weave using manual call tracking with parent-child relationships. + + This callback creates hierarchical traces for model evaluation: + - **Parent Call**: Represents the overall evaluation batch with metadata + - **Child Calls**: Individual sub-traces for each prompt-completion pair + + This structure provides better granularity and organization in the Weave UI, allowing you to drill down + from batch-level evaluation metrics to individual generation traces. + + Usage: + ```python + trainer = DPOTrainer(...) + weave_callback = WeaveTraceCallback(trainer=trainer, project_name="my-llm-training") + trainer.add_callback(weave_callback) + ``` + + Args: + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. + project_name (`str`): + The name of the Weave project where traces will be logged. + generation_config (`GenerationConfig`, *optional*): + The generation config to use for generating completions. + num_prompts (`int` or `None`, *optional*): + The number of prompts to generate completions for. If not provided, defaults to the number of examples in + the evaluation dataset. + freq (`int` or `None`, *optional*): + The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. 
+ """ + + def __init__( + self, + trainer: Trainer, + project_name: str, + generation_config: Optional[GenerationConfig] = None, + num_prompts: Optional[int] = None, + freq: Optional[int] = None, + ): + if not is_weave_available(): + raise ImportError( + "WeaveTraceCallback requires the `weave` package. To install, run `pip install weave`." + ) + + self.trainer = trainer + self.project_name = project_name + self.generation_config = generation_config + self.freq = freq + self._last_logged_step = -1 + self._weave_initialized = False + self._weave_client = None + + if self.trainer.eval_dataset is None: + raise ValueError( + "Trainer must have an evaluation dataset to use the WeaveTraceCallback." + ) + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def _initialize_weave(self): + """Initialize Weave if not already initialized.""" + if not self._weave_initialized: + import weave + + self._weave_client = weave.init(self.project_name) + self._weave_initialized = True + logger.info(f"Initialized Weave with project: {self.project_name}") + + def _generate_traced_completions( + self, + prompts: list[str], + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + accelerator: Accelerator, + generation_config: Optional[GenerationConfig], + batch_size: int, + training_step: int, + model_name: str, + ) -> dict: + """ + Generate completions with manual Weave call tracking using parent-child relationships. + + Creates a parent call for the evaluation batch and child calls for each prompt-completion pair. + + Args: + prompts: List of input prompts + model: The model to use for generation + tokenizer: The tokenizer to use + accelerator: Accelerator for distributed training + generation_config: Generation configuration + batch_size: Batch size for generation + training_step: Current training step + model_name: Name of the model being trained + + Returns: + Dict containing completions and metadata + """ + # Prepare inputs for the parent call (batch evaluation) + parent_inputs = { + "training_step": training_step, + "model_name": model_name, + "num_prompts": len(prompts), + "batch_size": batch_size, + "generation_config": ( + generation_config.to_dict() if generation_config else None + ), + } + + # Create parent call for the evaluation batch + parent_call = self._weave_client.create_call( + op="evaluate_model_completions", inputs=parent_inputs + ) + + child_call_ids = [] + successful_generations = 0 + + try: + # Do the actual completion generation (batch processing for efficiency) + completions = _generate_completions( + prompts=prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=generation_config, + batch_size=batch_size, + ) + + # Create child calls for each prompt-completion pair + for prompt, completion in zip(prompts, completions): + try: + # Create child call with parent relationship + child_inputs = { + "prompt": prompt, + "training_step": training_step, + "model_name": model_name, + "generation_config": ( + generation_config.to_dict() if generation_config else None + ), + } + + # Create child call with parent relationship + child_call = self._weave_client.create_call( + op="generate_single_completion", + inputs=child_inputs, + parent=parent_call, + ) + + # Prepare child output with completion details + child_output = { + "completion": completion, + "prompt_length": len(prompt) if prompt else 0, + "completion_length": len(completion) if completion 
else 0, + } + + # Finish child call successfully + self._weave_client.finish_call(child_call, output=child_output) + child_call_ids.append(child_call.id) + successful_generations += 1 + + except Exception as child_e: + logger.warning(f"Failed to create child call for prompt: {child_e}") + # Continue with other child calls even if one fails + + # Prepare parent output with summary information + parent_output = { + "total_prompts": len(prompts), + "successful_generations": successful_generations, + "evaluation_step": training_step, + "child_call_ids": child_call_ids, + "model_name": model_name, + } + + # End the parent call successfully + self._weave_client.finish_call(parent_call, output=parent_output) + + return { + "completions": completions, + "num_prompts": len(prompts), + "successful_generations": successful_generations, + "training_step": training_step, + "model_name": model_name, + "parent_call_id": parent_call.id, + "child_call_ids": child_call_ids, + } + + except Exception as e: + # End the parent call with exception + parent_output = { + "total_prompts": len(prompts), + "successful_generations": successful_generations, + "evaluation_step": training_step, + "error": str(e), + "child_call_ids": child_call_ids, + } + self._weave_client.finish_call(parent_call, output=parent_output) + logger.error(f"Error during traced completion generation: {e}") + raise + + def on_train_begin(self, args, state, control, **kwargs): + """Initialize Weave when training begins.""" + self._initialize_weave() + + def on_step_end(self, args, state, control, **kwargs): + # Only log once per step (this method may be called multiple times) + if state.global_step == self._last_logged_step: + return + + # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps) + freq = self.freq or state.eval_steps + if state.global_step % freq != 0: + return + + # Ensure Weave is initialized + self._initialize_weave() + + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + + # Get model name for tracing metadata + model_name = getattr(model.config, "_name_or_path", "unknown_model") + + with accelerator.split_between_processes( + self.eval_dataset["prompt"] + ) as prompts: + prompts = [ + maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] + for prompt in prompts + ] + + # Call the manual traced completion generator + result = self._generate_traced_completions( + prompts=prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + training_step=state.global_step, + model_name=model_name, + ) + + logger.info( + f"Logged evaluation trace to Weave at step {state.global_step}: " + f"{result['num_prompts']} prompts, {result['successful_generations']} successful generations, " + f"parent_call_id={result['parent_call_id']}, {len(result['child_call_ids'])} child traces" + ) + + # Save the last logged step, so we don't log the same completions multiple times + self._last_logged_step = state.global_step + + class MergeModelCallback(TrainerCallback): r""" A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based From bb93db8497bd8453afa66d717df7304c4d51330b Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Mon, 15 Sep 2025 18:45:45 +0530 Subject: [PATCH 2/5] feat: add tracing and evaluation options to the weave callback. 
--- trl/trainer/__init__.py | 2 + trl/trainer/callbacks.py | 340 ++++++++++++++++++++------------------- 2 files changed, 180 insertions(+), 162 deletions(-) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 51d4e1dfb2a..9bc332768ad 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -28,6 +28,7 @@ "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback", + "WeaveCallback", "WinRateCallback", ], "cpo_config": ["CPOConfig"], @@ -99,6 +100,7 @@ MergeModelCallback, RichProgressCallback, SyncRefModelCallback, + WeaveCallback, WinRateCallback, ) from .cpo_config import CPOConfig diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 17fafef7b43..980a7686c16 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -53,6 +53,7 @@ if is_weave_available(): import weave + from weave.trace.context import weave_client_context # Logger for module-level logging @@ -519,19 +520,37 @@ def on_step_end(self, args, state, control, **kwargs): class WeaveTraceCallback(TrainerCallback): r""" - A [`~transformers.TrainerCallback`] that logs completions to W&B Weave using manual call tracking with parent-child relationships. + A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. + The callback uses https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation step. - This callback creates hierarchical traces for model evaluation: - - **Parent Call**: Represents the overall evaluation batch with metadata - - **Child Calls**: Individual sub-traces for each prompt-completion pair + Supports two modes based on the `scorers` parameter: + - **Tracing Mode** (when scorers=None): Logs predictions for data exploration and analysis + - **Evaluation Mode** (when scorers provided): Logs predictions with scoring and summary metrics - This structure provides better granularity and organization in the Weave UI, allowing you to drill down - from batch-level evaluation metrics to individual generation traces. + Both modes use Weave's EvaluationLogger for structured, consistent data logging. + + The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, + making it more efficient and semantically correct. It gracefully handles missing weave + installation by logging warnings and skipping weave-specific functionality. It also checks + for existing weave clients before initializing new ones. Usage: ```python + # Tracing mode (just log predictions) trainer = DPOTrainer(...) - weave_callback = WeaveTraceCallback(trainer=trainer, project_name="my-llm-training") + weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training") + trainer.add_callback(weave_callback) + + # Evaluation mode (log predictions + scores + summary) + def accuracy_scorer(prompt: str, completion: str) -> float: + # Your scoring logic here (metadata available via eval_attributes) + return score + + weave_callback = WeaveCallback( + trainer=trainer, + project_name="my-llm-training", + scorers={"accuracy": accuracy_scorer} + ) trainer.add_callback(weave_callback) ``` @@ -540,40 +559,46 @@ class WeaveTraceCallback(TrainerCallback): Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` column containing the prompts for generating completions. project_name (`str`): - The name of the Weave project where traces will be logged. + The name of the Weave project where data will be logged. 
We default to using the wandb/weave project if not specified. + scorers (`Dict[str, Callable]`, *optional*): + Dictionary mapping scorer names to scorer functions. If None, operates in tracing mode (predictions only). + If provided, operates in evaluation mode (predictions + scores + summary). + Scorer functions should have signature: `scorer(prompt: str, completion: str) -> Union[float, int]` generation_config (`GenerationConfig`, *optional*): The generation config to use for generating completions. num_prompts (`int` or `None`, *optional*): The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset. - freq (`int` or `None`, *optional*): - The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. + dataset_name (`str`, *optional*): + Name for the dataset metadata in Weave. Defaults to "eval_dataset". + model_name (`str`, *optional*): + Name for the model metadata in Weave. If not provided, attempts to extract from model config. """ def __init__( self, trainer: Trainer, project_name: str, + scorers: Optional[dict[str, callable]] = None, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, - freq: Optional[int] = None, + dataset_name: Optional[str] = None, + model_name: Optional[str] = None, ): - if not is_weave_available(): - raise ImportError( - "WeaveTraceCallback requires the `weave` package. To install, run `pip install weave`." - ) self.trainer = trainer self.project_name = project_name + self.scorers = scorers or {} self.generation_config = generation_config - self.freq = freq + self.dataset_name = dataset_name or "eval_dataset" + self.model_name = model_name self._last_logged_step = -1 self._weave_initialized = False - self._weave_client = None + self._eval_logger = None if self.trainer.eval_dataset is None: raise ValueError( - "Trainer must have an evaluation dataset to use the WeaveTraceCallback." + "Trainer must have an evaluation dataset to use the WeaveCallback." ) else: self.eval_dataset = self.trainer.eval_dataset @@ -582,169 +607,72 @@ def __init__( self.eval_dataset = self.eval_dataset.select(range(num_prompts)) def _initialize_weave(self): - """Initialize Weave if not already initialized.""" + """Initialize Weave and EvaluationLogger if not already initialized.""" if not self._weave_initialized: - import weave + if not is_weave_available(): + logger.warning( + "Weave is not available. Please install weave to enable logging: `pip install weave`" + ) + return - self._weave_client = weave.init(self.project_name) - self._weave_initialized = True - logger.info(f"Initialized Weave with project: {self.project_name}") + # Check if weave client is already initialized + if wc := weave_client_context.get_weave_client(): + self._weave_client = wc + logger.info("Using existing Weave client") + else: + # Initialize new weave client + import weave - def _generate_traced_completions( - self, - prompts: list[str], - model: PreTrainedModel, - tokenizer: PreTrainedTokenizerBase, - accelerator: Accelerator, - generation_config: Optional[GenerationConfig], - batch_size: int, - training_step: int, - model_name: str, - ) -> dict: - """ - Generate completions with manual Weave call tracking using parent-child relationships. - - Creates a parent call for the evaluation batch and child calls for each prompt-completion pair. 
- - Args: - prompts: List of input prompts - model: The model to use for generation - tokenizer: The tokenizer to use - accelerator: Accelerator for distributed training - generation_config: Generation configuration - batch_size: Batch size for generation - training_step: Current training step - model_name: Name of the model being trained - - Returns: - Dict containing completions and metadata - """ - # Prepare inputs for the parent call (batch evaluation) - parent_inputs = { - "training_step": training_step, - "model_name": model_name, - "num_prompts": len(prompts), - "batch_size": batch_size, - "generation_config": ( - generation_config.to_dict() if generation_config else None - ), - } - - # Create parent call for the evaluation batch - parent_call = self._weave_client.create_call( - op="evaluate_model_completions", inputs=parent_inputs - ) - - child_call_ids = [] - successful_generations = 0 - - try: - # Do the actual completion generation (batch processing for efficiency) - completions = _generate_completions( - prompts=prompts, - model=model, - tokenizer=tokenizer, - accelerator=accelerator, - generation_config=generation_config, - batch_size=batch_size, - ) + self._weave_client = weave.init(self.project_name) + logger.info(f"Initialized Weave with project: {self.project_name}") - # Create child calls for each prompt-completion pair - for prompt, completion in zip(prompts, completions): + # Get model name for metadata + if self.model_name is None: try: - # Create child call with parent relationship - child_inputs = { - "prompt": prompt, - "training_step": training_step, - "model_name": model_name, - "generation_config": ( - generation_config.to_dict() if generation_config else None - ), - } - - # Create child call with parent relationship - child_call = self._weave_client.create_call( - op="generate_single_completion", - inputs=child_inputs, - parent=parent_call, + self.model_name = getattr( + self.trainer.model_wrapped.config, + "_name_or_path", + "unknown_model", ) + except: + self.model_name = "unknown_model" - # Prepare child output with completion details - child_output = { - "completion": completion, - "prompt_length": len(prompt) if prompt else 0, - "completion_length": len(completion) if completion else 0, - } + # EvaluationLogger will be created per evaluation step + # Store weave module for later use + from weave import EvaluationLogger - # Finish child call successfully - self._weave_client.finish_call(child_call, output=child_output) - child_call_ids.append(child_call.id) - successful_generations += 1 - - except Exception as child_e: - logger.warning(f"Failed to create child call for prompt: {child_e}") - # Continue with other child calls even if one fails - - # Prepare parent output with summary information - parent_output = { - "total_prompts": len(prompts), - "successful_generations": successful_generations, - "evaluation_step": training_step, - "child_call_ids": child_call_ids, - "model_name": model_name, - } + self._EvaluationLogger = EvaluationLogger - # End the parent call successfully - self._weave_client.finish_call(parent_call, output=parent_output) - - return { - "completions": completions, - "num_prompts": len(prompts), - "successful_generations": successful_generations, - "training_step": training_step, - "model_name": model_name, - "parent_call_id": parent_call.id, - "child_call_ids": child_call_ids, - } + self._weave_initialized = True - except Exception as e: - # End the parent call with exception - parent_output = { - "total_prompts": len(prompts), - 
"successful_generations": successful_generations, - "evaluation_step": training_step, - "error": str(e), - "child_call_ids": child_call_ids, - } - self._weave_client.finish_call(parent_call, output=parent_output) - logger.error(f"Error during traced completion generation: {e}") - raise + @property + def is_evaluation_mode(self) -> bool: + """True if scorers are provided (evaluation mode), False for tracing mode.""" + return bool(self.scorers) def on_train_begin(self, args, state, control, **kwargs): """Initialize Weave when training begins.""" self._initialize_weave() - def on_step_end(self, args, state, control, **kwargs): - # Only log once per step (this method may be called multiple times) + def on_evaluate(self, args, state, control, **kwargs): + # Only log once per evaluation (this method may be called multiple times) if state.global_step == self._last_logged_step: return - # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps) - freq = self.freq or state.eval_steps - if state.global_step % freq != 0: - return - # Ensure Weave is initialized self._initialize_weave() + # If weave initialization failed, skip logging + if not self._weave_initialized: + logger.debug("Weave not initialized, skipping logging") + return + tokenizer = kwargs["processing_class"] tokenizer.padding_side = "left" accelerator = self.trainer.accelerator model = self.trainer.model_wrapped - # Get model name for tracing metadata - model_name = getattr(model.config, "_name_or_path", "unknown_model") - + # All processes generate completions for their subset with accelerator.split_between_processes( self.eval_dataset["prompt"] ) as prompts: @@ -753,22 +681,110 @@ def on_step_end(self, args, state, control, **kwargs): for prompt in prompts ] - # Call the manual traced completion generator - result = self._generate_traced_completions( + # Generate completions using existing utility + completions = _generate_completions( prompts=prompts, model=model, tokenizer=tokenizer, accelerator=accelerator, generation_config=self.generation_config, batch_size=args.per_device_eval_batch_size, - training_step=state.global_step, - model_name=model_name, + ) + + # Gather all prompts and completions from all processes + all_prompts = gather_object(prompts) + all_completions = gather_object(completions) + + # Only the main process does the logging + if self.trainer.accelerator.is_main_process: + # Create a new EvaluationLogger for this evaluation step with metadata + eval_attributes = { + "training_step": state.global_step, + "model_name": self.model_name, + "generation_config": ( + self.generation_config.to_dict() if self.generation_config else None + ), + } + + eval_logger = self._EvaluationLogger( + model=self.model_name, + dataset=self.dataset_name, + eval_attributes=eval_attributes, + ) + + # Log all gathered predictions using EvaluationLogger + successful_predictions = 0 + total_score_values = {} # For summary statistics + + for prompt, completion in zip(all_prompts, all_completions): + try: + # Log prediction to Weave + pred_logger = eval_logger.log_prediction( + inputs={"prompt": prompt}, output=completion + ) + + # Apply scorers if in evaluation mode + if self.is_evaluation_mode: + for scorer_name, scorer_func in self.scorers.items(): + try: + # Scorer context no longer needs metadata (it's in eval_attributes) + score = scorer_func(prompt, completion) + pred_logger.log_score(scorer=scorer_name, score=score) + + # Collect scores for summary + if scorer_name not in total_score_values: + 
total_score_values[scorer_name] = [] + total_score_values[scorer_name].append(score) + + except Exception as scorer_e: + logger.warning( + f"Failed to apply scorer '{scorer_name}': {scorer_e}" + ) + + # Finish prediction logging + pred_logger.finish() + successful_predictions += 1 + + except Exception as pred_e: + logger.warning(f"Failed to log prediction for prompt: {pred_e}") + # Continue with other predictions even if one fails + + # Log summary if in evaluation mode, otherwise finish the logger + if self.is_evaluation_mode and total_score_values: + try: + # Calculate summary statistics + summary_stats = { + "total_predictions": len(all_prompts), + "successful_predictions": successful_predictions, + } + + # Add average scores for each scorer + for scorer_name, scores in total_score_values.items(): + if scores: # Only if we have valid scores + summary_stats[f"avg_{scorer_name}"] = sum(scores) / len( + scores + ) + + eval_logger.log_summary(summary_stats) + + except Exception as summary_e: + logger.warning(f"Failed to log summary: {summary_e}") + else: + # In tracing mode (no scorers), we need to properly finish the evaluation logger + try: + eval_logger.finish() + except Exception as finish_e: + logger.warning(f"Failed to finish evaluation logger: {finish_e}") + + # Log success message + mode = "evaluation" if self.is_evaluation_mode else "tracing" + scorer_info = ( + f" with {len(self.scorers)} scorers" if self.is_evaluation_mode else "" ) logger.info( - f"Logged evaluation trace to Weave at step {state.global_step}: " - f"{result['num_prompts']} prompts, {result['successful_generations']} successful generations, " - f"parent_call_id={result['parent_call_id']}, {len(result['child_call_ids'])} child traces" + f"Logged {mode} data to Weave at step {state.global_step}: " + f"{len(all_prompts)} prompts, {successful_predictions} successful predictions{scorer_info}" ) # Save the last logged step, so we don't log the same completions multiple times From 395bb0a001e6a388d7f3c7ca00e4767d32238bf5 Mon Sep 17 00:00:00 2001 From: Bharat Ramanathan Date: Mon, 15 Sep 2025 19:03:16 +0530 Subject: [PATCH 3/5] feat: make project name optional in weave callback --- trl/trainer/callbacks.py | 78 +++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 46 deletions(-) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 980a7686c16..f15e2d1a850 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -54,6 +54,7 @@ if is_weave_available(): import weave from weave.trace.context import weave_client_context + from weave import EvaluationLogger # Logger for module-level logging @@ -518,7 +519,7 @@ def on_step_end(self, args, state, control, **kwargs): self._last_logged_step = state.global_step -class WeaveTraceCallback(TrainerCallback): +class WeaveCallback(TrainerCallback): r""" A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. The callback uses https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation step. @@ -538,7 +539,11 @@ class WeaveTraceCallback(TrainerCallback): ```python # Tracing mode (just log predictions) trainer = DPOTrainer(...) 
- weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training") + weave_callback = WeaveTraceCallback(trainer=trainer) # project_name optional + trainer.add_callback(weave_callback) + + # Or specify a project name + weave_callback = WeaveTraceCallback(trainer=trainer, project_name="my-llm-training") trainer.add_callback(weave_callback) # Evaluation mode (log predictions + scores + summary) @@ -546,9 +551,9 @@ def accuracy_scorer(prompt: str, completion: str) -> float: # Your scoring logic here (metadata available via eval_attributes) return score - weave_callback = WeaveCallback( + weave_callback = WeaveTraceCallback( trainer=trainer, - project_name="my-llm-training", + project_name="my-llm-training", # optional and needed only if weave client is not initialized scorers={"accuracy": accuracy_scorer} ) trainer.add_callback(weave_callback) @@ -558,8 +563,9 @@ def accuracy_scorer(prompt: str, completion: str) -> float: trainer (`Trainer`): Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` column containing the prompts for generating completions. - project_name (`str`): - The name of the Weave project where data will be logged. We default to using the wandb/weave project if not specified. + project_name (`str`, *optional*): + The name of the Weave project where data will be logged. If not provided, will try to use existing weave client + or fall back to the active wandb run's project name. Raises an error if none of these are available. scorers (`Dict[str, Callable]`, *optional*): Dictionary mapping scorer names to scorer functions. If None, operates in tracing mode (predictions only). If provided, operates in evaluation mode (predictions + scores + summary). @@ -578,7 +584,7 @@ def accuracy_scorer(prompt: str, completion: str) -> float: def __init__( self, trainer: Trainer, - project_name: str, + project_name: Optional[str] = None, scorers: Optional[dict[str, callable]] = None, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, @@ -615,18 +621,32 @@ def _initialize_weave(self): ) return - # Check if weave client is already initialized if wc := weave_client_context.get_weave_client(): self._weave_client = wc - logger.info("Using existing Weave client") else: - # Initialize new weave client - import weave + if self.project_name is None: + if is_wandb_available(): + import wandb + + if wandb.run is not None: + self.project_name = ( + wandb.run.entity + "/" + wandb.run.project + ) + logger.info( + f"Using project name from active wandb run: {self.project_name}" + ) + + if self.project_name is None: + raise ValueError( + "No existing Weave client found and no project_name provided. " + "Please either initialize weave with `weave.init('project-name')`, " + "provide a project_name to the `WeaveTraceCallback`, " + "or ensure an active wandb run exists." 
+ ) self._weave_client = weave.init(self.project_name) logger.info(f"Initialized Weave with project: {self.project_name}") - # Get model name for metadata if self.model_name is None: try: self.model_name = getattr( @@ -637,10 +657,6 @@ def _initialize_weave(self): except: self.model_name = "unknown_model" - # EvaluationLogger will be created per evaluation step - # Store weave module for later use - from weave import EvaluationLogger - self._EvaluationLogger = EvaluationLogger self._weave_initialized = True @@ -655,14 +671,11 @@ def on_train_begin(self, args, state, control, **kwargs): self._initialize_weave() def on_evaluate(self, args, state, control, **kwargs): - # Only log once per evaluation (this method may be called multiple times) if state.global_step == self._last_logged_step: return - # Ensure Weave is initialized self._initialize_weave() - # If weave initialization failed, skip logging if not self._weave_initialized: logger.debug("Weave not initialized, skipping logging") return @@ -672,7 +685,6 @@ def on_evaluate(self, args, state, control, **kwargs): accelerator = self.trainer.accelerator model = self.trainer.model_wrapped - # All processes generate completions for their subset with accelerator.split_between_processes( self.eval_dataset["prompt"] ) as prompts: @@ -681,7 +693,6 @@ def on_evaluate(self, args, state, control, **kwargs): for prompt in prompts ] - # Generate completions using existing utility completions = _generate_completions( prompts=prompts, model=model, @@ -691,13 +702,10 @@ def on_evaluate(self, args, state, control, **kwargs): batch_size=args.per_device_eval_batch_size, ) - # Gather all prompts and completions from all processes all_prompts = gather_object(prompts) all_completions = gather_object(completions) - # Only the main process does the logging if self.trainer.accelerator.is_main_process: - # Create a new EvaluationLogger for this evaluation step with metadata eval_attributes = { "training_step": state.global_step, "model_name": self.model_name, @@ -712,26 +720,21 @@ def on_evaluate(self, args, state, control, **kwargs): eval_attributes=eval_attributes, ) - # Log all gathered predictions using EvaluationLogger successful_predictions = 0 total_score_values = {} # For summary statistics for prompt, completion in zip(all_prompts, all_completions): try: - # Log prediction to Weave pred_logger = eval_logger.log_prediction( inputs={"prompt": prompt}, output=completion ) - # Apply scorers if in evaluation mode if self.is_evaluation_mode: for scorer_name, scorer_func in self.scorers.items(): try: - # Scorer context no longer needs metadata (it's in eval_attributes) score = scorer_func(prompt, completion) pred_logger.log_score(scorer=scorer_name, score=score) - # Collect scores for summary if scorer_name not in total_score_values: total_score_values[scorer_name] = [] total_score_values[scorer_name].append(score) @@ -741,7 +744,6 @@ def on_evaluate(self, args, state, control, **kwargs): f"Failed to apply scorer '{scorer_name}': {scorer_e}" ) - # Finish prediction logging pred_logger.finish() successful_predictions += 1 @@ -749,16 +751,13 @@ def on_evaluate(self, args, state, control, **kwargs): logger.warning(f"Failed to log prediction for prompt: {pred_e}") # Continue with other predictions even if one fails - # Log summary if in evaluation mode, otherwise finish the logger if self.is_evaluation_mode and total_score_values: try: - # Calculate summary statistics summary_stats = { "total_predictions": len(all_prompts), "successful_predictions": 
successful_predictions, } - # Add average scores for each scorer for scorer_name, scores in total_score_values.items(): if scores: # Only if we have valid scores summary_stats[f"avg_{scorer_name}"] = sum(scores) / len( @@ -770,24 +769,11 @@ def on_evaluate(self, args, state, control, **kwargs): except Exception as summary_e: logger.warning(f"Failed to log summary: {summary_e}") else: - # In tracing mode (no scorers), we need to properly finish the evaluation logger try: eval_logger.finish() except Exception as finish_e: logger.warning(f"Failed to finish evaluation logger: {finish_e}") - # Log success message - mode = "evaluation" if self.is_evaluation_mode else "tracing" - scorer_info = ( - f" with {len(self.scorers)} scorers" if self.is_evaluation_mode else "" - ) - - logger.info( - f"Logged {mode} data to Weave at step {state.global_step}: " - f"{len(all_prompts)} prompts, {successful_predictions} successful predictions{scorer_info}" - ) - - # Save the last logged step, so we don't log the same completions multiple times self._last_logged_step = state.global_step From cf477ffc2403b74e1ae4b216c08375d53ef6f661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 17 Sep 2025 23:38:32 +0000 Subject: [PATCH 4/5] doc style and nits --- docs/source/callbacks.md | 4 ++ trl/import_utils.py | 14 +++--- trl/trainer/callbacks.py | 99 ++++++++++++++-------------------------- 3 files changed, 44 insertions(+), 73 deletions(-) diff --git a/docs/source/callbacks.md b/docs/source/callbacks.md index cbafdc8a7b5..b89de277309 100644 --- a/docs/source/callbacks.md +++ b/docs/source/callbacks.md @@ -23,3 +23,7 @@ ## BEMACallback [[autodoc]] BEMACallback + +## WeaveCallback + +[[autodoc]] WeaveCallback diff --git a/trl/import_utils.py b/trl/import_utils.py index e14b05bc9b9..e495a845dae 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -27,7 +27,8 @@ # Use same as transformers.utils.import_utils _deepspeed_available = _is_package_available("deepspeed") _fastapi_available = _is_package_available("fastapi") -_is_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) +_joblib_available = _is_package_available("joblib") +_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) _llm_blender_available = _is_package_available("llm_blender") _mergekit_available = _is_package_available("mergekit") _pydantic_available = _is_package_available("pydantic") @@ -36,7 +37,6 @@ _uvicorn_available = _is_package_available("uvicorn") _vllm_available = _is_package_available("vllm") _vllm_ascend_available = _is_package_available("vllm_ascend") -_joblib_available = _is_package_available("joblib") _weave_available = _is_package_available("weave") @@ -48,8 +48,12 @@ def is_fastapi_available() -> bool: return _fastapi_available +def is_joblib_available() -> bool: + return _joblib_available + + def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool: - return _is_liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version) + return _liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version) def is_llm_blender_available() -> bool: @@ -84,10 +88,6 @@ def is_vllm_ascend_available() -> bool: return _vllm_ascend_available -def is_joblib_available() -> bool: - return _joblib_available - - def is_weave_available() -> bool: return _weave_available diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py 
index c3fba51a766..68fb6f97b72 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -53,8 +53,8 @@ if is_weave_available(): import weave - from weave.trace.context import weave_client_context from weave import EvaluationLogger + from weave.trace.context import weave_client_context # Logger for module-level logging @@ -340,8 +340,6 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: self.trainer.log({"eval_win_rate": win_rate}) if "wandb" in args.report_to: - import wandb - if wandb.run is not None: df = _win_rate_completions_df( state=state, @@ -403,8 +401,6 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra self.trainer.log({"eval_win_rate": win_rate}) if "wandb" in args.report_to: - import wandb - if wandb.run is not None: df = _win_rate_completions_df( state=state, @@ -521,8 +517,9 @@ def on_step_end(self, args, state, control, **kwargs): class WeaveCallback(TrainerCallback): r""" - A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. - The callback uses https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation step. + A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. The callback uses + https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/ to log traces and evaluations at each evaluation + step. Supports two modes based on the `scorers` parameter: - **Tracing Mode** (when scorers=None): Logs predictions for data exploration and analysis @@ -530,10 +527,9 @@ class WeaveCallback(TrainerCallback): Both modes use Weave's EvaluationLogger for structured, consistent data logging. - The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, - making it more efficient and semantically correct. It gracefully handles missing weave - installation by logging warnings and skipping weave-specific functionality. It also checks - for existing weave clients before initializing new ones. + The callback logs data during evaluation phases (`on_evaluate`) rather than training steps, making it more + efficient and semantically correct. It gracefully handles missing weave installation by logging warnings and + skipping weave-specific functionality. It also checks for existing weave clients before initializing new ones. Usage: ```python @@ -546,15 +542,17 @@ class WeaveCallback(TrainerCallback): weave_callback = WeaveTraceCallback(trainer=trainer, project_name="my-llm-training") trainer.add_callback(weave_callback) + # Evaluation mode (log predictions + scores + summary) def accuracy_scorer(prompt: str, completion: str) -> float: # Your scoring logic here (metadata available via eval_attributes) return score + weave_callback = WeaveTraceCallback( trainer=trainer, project_name="my-llm-training", # optional and needed only if weave client is not initialized - scorers={"accuracy": accuracy_scorer} + scorers={"accuracy": accuracy_scorer}, ) trainer.add_callback(weave_callback) ``` @@ -564,19 +562,19 @@ def accuracy_scorer(prompt: str, completion: str) -> float: Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` column containing the prompts for generating completions. project_name (`str`, *optional*): - The name of the Weave project where data will be logged. If not provided, will try to use existing weave client + Name of the Weave project where data will be logged. 
If not provided, will try to use existing weave client or fall back to the active wandb run's project name. Raises an error if none of these are available. - scorers (`Dict[str, Callable]`, *optional*): - Dictionary mapping scorer names to scorer functions. If None, operates in tracing mode (predictions only). - If provided, operates in evaluation mode (predictions + scores + summary). - Scorer functions should have signature: `scorer(prompt: str, completion: str) -> Union[float, int]` + scorers (`dict[str, Callable]`, *optional*): + Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions + only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should + have signature: `scorer(prompt: str, completion: str) -> Union[float, int]` generation_config (`GenerationConfig`, *optional*): - The generation config to use for generating completions. + Generation config to use for generating completions. num_prompts (`int` or `None`, *optional*): - The number of prompts to generate completions for. If not provided, defaults to the number of examples in - the evaluation dataset. - dataset_name (`str`, *optional*): - Name for the dataset metadata in Weave. Defaults to "eval_dataset". + Number of prompts to generate completions for. If not provided, defaults to the number of examples in the + evaluation dataset. + dataset_name (`str`, *optional*, defaults to `"eval_dataset"`): + Name for the dataset metadata in Weave. model_name (`str`, *optional*): Name for the model metadata in Weave. If not provided, attempts to extract from model config. """ @@ -588,24 +586,21 @@ def __init__( scorers: Optional[dict[str, callable]] = None, generation_config: Optional[GenerationConfig] = None, num_prompts: Optional[int] = None, - dataset_name: Optional[str] = None, + dataset_name: str = "eval_dataset", model_name: Optional[str] = None, ): - self.trainer = trainer self.project_name = project_name self.scorers = scorers or {} self.generation_config = generation_config - self.dataset_name = dataset_name or "eval_dataset" + self.dataset_name = dataset_name self.model_name = model_name self._last_logged_step = -1 self._weave_initialized = False self._eval_logger = None if self.trainer.eval_dataset is None: - raise ValueError( - "Trainer must have an evaluation dataset to use the WeaveCallback." - ) + raise ValueError("Trainer must have an evaluation dataset to use the WeaveCallback.") else: self.eval_dataset = self.trainer.eval_dataset @@ -616,9 +611,7 @@ def _initialize_weave(self): """Initialize Weave and EvaluationLogger if not already initialized.""" if not self._weave_initialized: if not is_weave_available(): - logger.warning( - "Weave is not available. Please install weave to enable logging: `pip install weave`" - ) + logger.warning("Weave is not available. 
Please install weave to enable logging: `pip install weave`") return if wc := weave_client_context.get_weave_client(): @@ -626,15 +619,9 @@ def _initialize_weave(self): else: if self.project_name is None: if is_wandb_available(): - import wandb - if wandb.run is not None: - self.project_name = ( - wandb.run.entity + "/" + wandb.run.project - ) - logger.info( - f"Using project name from active wandb run: {self.project_name}" - ) + self.project_name = wandb.run.entity + "/" + wandb.run.project + logger.info(f"Using project name from active wandb run: {self.project_name}") if self.project_name is None: raise ValueError( @@ -648,14 +635,7 @@ def _initialize_weave(self): logger.info(f"Initialized Weave with project: {self.project_name}") if self.model_name is None: - try: - self.model_name = getattr( - self.trainer.model_wrapped.config, - "_name_or_path", - "unknown_model", - ) - except: - self.model_name = "unknown_model" + self.model_name = getattr(self.trainer.model_wrapped.config, "_name_or_path", "unknown_model") self._EvaluationLogger = EvaluationLogger @@ -685,13 +665,8 @@ def on_evaluate(self, args, state, control, **kwargs): accelerator = self.trainer.accelerator model = self.trainer.model_wrapped - with accelerator.split_between_processes( - self.eval_dataset["prompt"] - ) as prompts: - prompts = [ - maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] - for prompt in prompts - ] + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts] completions = _generate_completions( prompts=prompts, @@ -709,9 +684,7 @@ def on_evaluate(self, args, state, control, **kwargs): eval_attributes = { "training_step": state.global_step, "model_name": self.model_name, - "generation_config": ( - self.generation_config.to_dict() if self.generation_config else None - ), + "generation_config": (self.generation_config.to_dict() if self.generation_config else None), } eval_logger = self._EvaluationLogger( @@ -725,9 +698,7 @@ def on_evaluate(self, args, state, control, **kwargs): for prompt, completion in zip(all_prompts, all_completions): try: - pred_logger = eval_logger.log_prediction( - inputs={"prompt": prompt}, output=completion - ) + pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion) if self.is_evaluation_mode: for scorer_name, scorer_func in self.scorers.items(): @@ -740,9 +711,7 @@ def on_evaluate(self, args, state, control, **kwargs): total_score_values[scorer_name].append(score) except Exception as scorer_e: - logger.warning( - f"Failed to apply scorer '{scorer_name}': {scorer_e}" - ) + logger.warning(f"Failed to apply scorer '{scorer_name}': {scorer_e}") pred_logger.finish() successful_predictions += 1 @@ -760,9 +729,7 @@ def on_evaluate(self, args, state, control, **kwargs): for scorer_name, scores in total_score_values.items(): if scores: # Only if we have valid scores - summary_stats[f"avg_{scorer_name}"] = sum(scores) / len( - scores - ) + summary_stats[f"avg_{scorer_name}"] = sum(scores) / len(scores) eval_logger.log_summary(summary_stats) From d9e808ff3bdec7a727ba3e8d8460b6cb9a138d3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 17 Sep 2025 23:53:36 +0000 Subject: [PATCH 5/5] top import --- trl/__init__.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/trl/__init__.py b/trl/__init__.py index fafe7461a4e..0fb584f41ab 100644 --- a/trl/__init__.py +++ 
b/trl/__init__.py @@ -95,7 +95,13 @@ "XPOConfig", "XPOTrainer", ], - "trainer.callbacks": ["BEMACallback", "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], + "trainer.callbacks": [ + "BEMACallback", + "MergeModelCallback", + "RichProgressCallback", + "SyncRefModelCallback", + "WeaveCallback", + ], "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], } @@ -170,7 +176,13 @@ XPOConfig, XPOTrainer, ) - from .trainer.callbacks import BEMACallback, MergeModelCallback, RichProgressCallback, SyncRefModelCallback + from .trainer.callbacks import ( + BEMACallback, + MergeModelCallback, + RichProgressCallback, + SyncRefModelCallback, + WeaveCallback, + ) from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config else:
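
A minimal usage sketch of the callback this series introduces, assuming the final `WeaveCallback` API from PATCH 4/5 (`project_name`, `scorers`, `num_prompts` arguments and top-level export from `trl`) and a toy length-based scorer; the trainer construction is elided with `...` as in the docstring examples, so this is illustrative rather than a definitive recipe.

```python
from trl import DPOTrainer, WeaveCallback


# Toy scorer following the documented signature:
# scorer(prompt: str, completion: str) -> Union[float, int]
def length_scorer(prompt: str, completion: str) -> float:
    # Reward longer completions, capped at 1.0 (illustrative only).
    return min(len(completion) / 512, 1.0)


# Trainer setup elided; the eval dataset must contain a "prompt" column.
trainer = DPOTrainer(...)

# Tracing mode: no scorers, predictions are logged on each evaluation.
tracing_callback = WeaveCallback(trainer=trainer)

# Evaluation mode: scorers run per prediction and are averaged in the summary.
eval_callback = WeaveCallback(
    trainer=trainer,
    project_name="my-llm-training",  # optional if a weave client or wandb run already exists
    scorers={"length": length_scorer},
    num_prompts=32,  # subsample the eval dataset
)

trainer.add_callback(eval_callback)
trainer.train()
```

With scorers provided, each evaluation step logs per-prediction scores plus `avg_length`-style summary statistics; without scorers, only the prompt/completion traces are recorded.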