4 changes: 4 additions & 0 deletions docs/source/callbacks.md
@@ -23,3 +23,7 @@
## BEMACallback

[[autodoc]] BEMACallback

## WeaveCallback

[[autodoc]] WeaveCallback
16 changes: 14 additions & 2 deletions trl/__init__.py
@@ -95,7 +95,13 @@
"XPOConfig",
"XPOTrainer",
],
"trainer.callbacks": ["BEMACallback", "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.callbacks": [
"BEMACallback",
"MergeModelCallback",
"RichProgressCallback",
"SyncRefModelCallback",
"WeaveCallback",
],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
}

@@ -170,7 +176,13 @@
XPOConfig,
XPOTrainer,
)
from .trainer.callbacks import BEMACallback, MergeModelCallback, RichProgressCallback, SyncRefModelCallback
from .trainer.callbacks import (
BEMACallback,
MergeModelCallback,
RichProgressCallback,
SyncRefModelCallback,
WeaveCallback,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

else:
15 changes: 10 additions & 5 deletions trl/import_utils.py
@@ -27,7 +27,8 @@
# Use same as transformers.utils.import_utils
_deepspeed_available = _is_package_available("deepspeed")
_fastapi_available = _is_package_available("fastapi")
_is_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
_joblib_available = _is_package_available("joblib")
_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
_llm_blender_available = _is_package_available("llm_blender")
_mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic")
@@ -36,7 +37,7 @@
_uvicorn_available = _is_package_available("uvicorn")
_vllm_available = _is_package_available("vllm")
_vllm_ascend_available = _is_package_available("vllm_ascend")
_joblib_available = _is_package_available("joblib")
_weave_available = _is_package_available("weave")


def is_deepspeed_available() -> bool:
@@ -47,8 +48,12 @@ def is_fastapi_available() -> bool:
return _fastapi_available


def is_joblib_available() -> bool:
return _joblib_available


def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool:
return _is_liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version)
return _liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version)


def is_llm_blender_available() -> bool:
@@ -83,8 +88,8 @@ def is_vllm_ascend_available() -> bool:
return _vllm_ascend_available


def is_joblib_available() -> bool:
return _joblib_available
def is_weave_available() -> bool:
return _weave_available


class _LazyModule(ModuleType):
2 changes: 2 additions & 0 deletions trl/trainer/__init__.py
@@ -26,6 +26,7 @@
"MergeModelCallback",
"RichProgressCallback",
"SyncRefModelCallback",
"WeaveCallback",
"WinRateCallback",
],
"cpo_config": ["CPOConfig"],
@@ -85,6 +86,7 @@
MergeModelCallback,
RichProgressCallback,
SyncRefModelCallback,
WeaveCallback,
WinRateCallback,
)
from .cpo_config import CPOConfig
240 changes: 235 additions & 5 deletions trl/trainer/callbacks.py
@@ -35,7 +35,7 @@
from transformers.utils import is_rich_available

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_mergekit_available
from ..import_utils import is_mergekit_available, is_weave_available
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
@@ -51,6 +51,11 @@
if is_wandb_available():
import wandb

if is_weave_available():
import weave
from weave import EvaluationLogger
from weave.trace.context import weave_client_context


# Logger for module-level logging
logger = logging.getLogger(__name__)
@@ -335,8 +340,6 @@ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control:
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb

if wandb.run is not None:
df = _win_rate_completions_df(
state=state,
@@ -398,8 +401,6 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
self.trainer.log({"eval_win_rate": win_rate})

if "wandb" in args.report_to:
import wandb

if wandb.run is not None:
df = _win_rate_completions_df(
state=state,
@@ -514,6 +515,235 @@ def on_step_end(self, args, state, control, **kwargs):
self._last_logged_step = state.global_step


class WeaveCallback(TrainerCallback):
r"""
    A [`~transformers.TrainerCallback`] that logs traces and evaluations to W&B Weave. It uses Weave's
    [`EvaluationLogger`](https://weave-docs.wandb.ai/guides/evaluation/evaluation_logger/) to log traces and
    evaluations at each evaluation step.

    Supports two modes, selected by the `scorers` parameter:
    - **Tracing mode** (`scorers=None`): logs predictions for data exploration and analysis
    - **Evaluation mode** (`scorers` provided): logs predictions together with scores and summary metrics

Both modes use Weave's EvaluationLogger for structured, consistent data logging.

    The callback logs data during evaluation phases (`on_evaluate`) rather than at every training step, which avoids
    per-step overhead and matches evaluation semantics. If `weave` is not installed, it logs a warning and skips all
    Weave-specific functionality, and it reuses an existing Weave client instead of initializing a new one whenever
    possible.

Usage:
```python
# Tracing mode (just log predictions)
trainer = DPOTrainer(...)
    weave_callback = WeaveCallback(trainer=trainer)  # project_name optional
trainer.add_callback(weave_callback)

# Or specify a project name
    weave_callback = WeaveCallback(trainer=trainer, project_name="my-llm-training")
trainer.add_callback(weave_callback)


# Evaluation mode (log predictions + scores + summary)
    def accuracy_scorer(prompt: str, completion: str) -> float:
        # Your scoring logic here (evaluation metadata is logged separately via `eval_attributes`)
        return 1.0 if completion.strip() else 0.0  # placeholder metric


    weave_callback = WeaveCallback(
        trainer=trainer,
        project_name="my-llm-training",  # optional; only needed if no Weave client is initialized
scorers={"accuracy": accuracy_scorer},
)
trainer.add_callback(weave_callback)
```

Args:
trainer (`Trainer`):
Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"`
column containing the prompts for generating completions.
project_name (`str`, *optional*):
            Name of the Weave project where data will be logged. If not provided, the callback reuses an existing
            Weave client if available, or falls back to the active wandb run's project name; if neither exists, an
            error is raised.
scorers (`dict[str, Callable]`, *optional*):
            Dictionary mapping scorer names to scorer functions. If `None`, the callback operates in tracing mode
            (predictions only). If provided, it operates in evaluation mode (predictions + scores + summary). Scorer
            functions should have the signature `scorer(prompt: str, completion: str) -> Union[float, int]`.
generation_config (`GenerationConfig`, *optional*):
Generation config to use for generating completions.
num_prompts (`int` or `None`, *optional*):
Number of prompts to generate completions for. If not provided, defaults to the number of examples in the
evaluation dataset.
dataset_name (`str`, *optional*, defaults to `"eval_dataset"`):
Name for the dataset metadata in Weave.
model_name (`str`, *optional*):
Name for the model metadata in Weave. If not provided, attempts to extract from model config.
"""

def __init__(
self,
trainer: Trainer,
project_name: Optional[str] = None,
scorers: Optional[dict[str, callable]] = None,
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
dataset_name: str = "eval_dataset",
model_name: Optional[str] = None,
):
self.trainer = trainer
self.project_name = project_name
self.scorers = scorers or {}
self.generation_config = generation_config
self.dataset_name = dataset_name
self.model_name = model_name
self._last_logged_step = -1
self._weave_initialized = False
self._eval_logger = None

        if self.trainer.eval_dataset is None:
            raise ValueError("Trainer must have an evaluation dataset to use the WeaveCallback.")
        self.eval_dataset = self.trainer.eval_dataset

if num_prompts is not None:
self.eval_dataset = self.eval_dataset.select(range(num_prompts))

def _initialize_weave(self):
"""Initialize Weave and EvaluationLogger if not already initialized."""
if not self._weave_initialized:
if not is_weave_available():
logger.warning("Weave is not available. Please install weave to enable logging: `pip install weave`")
return

if wc := weave_client_context.get_weave_client():
self._weave_client = wc
else:
if self.project_name is None:
if is_wandb_available():
if wandb.run is not None:
self.project_name = wandb.run.entity + "/" + wandb.run.project
logger.info(f"Using project name from active wandb run: {self.project_name}")

if self.project_name is None:
raise ValueError(
"No existing Weave client found and no project_name provided. "
"Please either initialize weave with `weave.init('project-name')`, "
"provide a project_name to the `WeaveTraceCallback`, "
"or ensure an active wandb run exists."
)

self._weave_client = weave.init(self.project_name)
logger.info(f"Initialized Weave with project: {self.project_name}")

if self.model_name is None:
self.model_name = getattr(self.trainer.model_wrapped.config, "_name_or_path", "unknown_model")

self._EvaluationLogger = EvaluationLogger

self._weave_initialized = True

@property
def is_evaluation_mode(self) -> bool:
"""True if scorers are provided (evaluation mode), False for tracing mode."""
return bool(self.scorers)

def on_train_begin(self, args, state, control, **kwargs):
"""Initialize Weave when training begins."""
self._initialize_weave()

def on_evaluate(self, args, state, control, **kwargs):
if state.global_step == self._last_logged_step:
return

        # No-op if Weave was already initialized in `on_train_begin`
        self._initialize_weave()

if not self._weave_initialized:
logger.debug("Weave not initialized, skipping logging")
return

        tokenizer = kwargs["processing_class"]
        tokenizer.padding_side = "left"  # left-padding so batched generation works correctly
accelerator = self.trainer.accelerator
model = self.trainer.model_wrapped

        # Shard the eval prompts across processes; each process generates completions for its shard
        with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts:
prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts]

completions = _generate_completions(
prompts=prompts,
model=model,
tokenizer=tokenizer,
accelerator=accelerator,
generation_config=self.generation_config,
batch_size=args.per_device_eval_batch_size,
)

        # Gather all shards so the main process can log the full set
        all_prompts = gather_object(prompts)
        all_completions = gather_object(completions)

if self.trainer.accelerator.is_main_process:
eval_attributes = {
"training_step": state.global_step,
"model_name": self.model_name,
"generation_config": (self.generation_config.to_dict() if self.generation_config else None),
}

eval_logger = self._EvaluationLogger(
model=self.model_name,
dataset=self.dataset_name,
eval_attributes=eval_attributes,
)

successful_predictions = 0
total_score_values = {} # For summary statistics

for prompt, completion in zip(all_prompts, all_completions):
try:
pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)

if self.is_evaluation_mode:
for scorer_name, scorer_func in self.scorers.items():
try:
score = scorer_func(prompt, completion)
pred_logger.log_score(scorer=scorer_name, score=score)

if scorer_name not in total_score_values:
total_score_values[scorer_name] = []
total_score_values[scorer_name].append(score)

except Exception as scorer_e:
logger.warning(f"Failed to apply scorer '{scorer_name}': {scorer_e}")

pred_logger.finish()
successful_predictions += 1

except Exception as pred_e:
logger.warning(f"Failed to log prediction for prompt: {pred_e}")
# Continue with other predictions even if one fails

                # `log_summary` also finalizes the evaluation, so an explicit `finish()` is only needed in tracing mode
                if self.is_evaluation_mode and total_score_values:
try:
summary_stats = {
"total_predictions": len(all_prompts),
"successful_predictions": successful_predictions,
}

for scorer_name, scores in total_score_values.items():
if scores: # Only if we have valid scores
summary_stats[f"avg_{scorer_name}"] = sum(scores) / len(scores)

eval_logger.log_summary(summary_stats)

except Exception as summary_e:
logger.warning(f"Failed to log summary: {summary_e}")
else:
try:
eval_logger.finish()
except Exception as finish_e:
logger.warning(f"Failed to finish evaluation logger: {finish_e}")

self._last_logged_step = state.global_step


class MergeModelCallback(TrainerCallback):
r"""
A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
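For context, here is a minimal end-to-end sketch of how the new callback might be used. This is not part of the diff: the SFT setup, model id, toy datasets, and `accuracy_scorer` are illustrative assumptions.

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer, WeaveCallback

# Toy data; WeaveCallback requires the eval split to provide a "prompt" column
train_dataset = Dataset.from_dict({"text": ["The capital of France is Paris."] * 32})
eval_dataset = Dataset.from_dict({"prompt": ["The capital of France is"] * 4})


def accuracy_scorer(prompt: str, completion: str) -> float:
    # Illustrative metric: did the model mention "Paris"?
    return float("Paris" in completion)


trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # any causal LM id works here
    args=SFTConfig(output_dir="weave-demo", eval_strategy="steps", eval_steps=10),
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Passing scorers enables evaluation mode; scorers=None would give tracing mode
trainer.add_callback(
    WeaveCallback(trainer=trainer, project_name="my-llm-training", scorers={"accuracy": accuracy_scorer})
)
trainer.train()
```

With `scorers=None` the same callback logs only prompt/completion pairs; supplying scorers additionally logs per-prediction scores and aggregate summary statistics at each evaluation step.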