From 0195a06f5f0a9c8d170b63378f04122e46aab5d7 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 17 Sep 2025 21:55:15 +0200 Subject: [PATCH 01/58] add first version to deal with images on verifiers, dealt with environment and grpo trainer, mainly the get logps function --- notes.md | 8 + verifiers/envs/environment.py | 24 ++- verifiers/trainers/grpo_trainer.py | 241 +++++++++++++++++++++-------- 3 files changed, 207 insertions(+), 66 deletions(-) create mode 100644 notes.md diff --git a/notes.md b/notes.md new file mode 100644 index 000000000..8db212e2c --- /dev/null +++ b/notes.md @@ -0,0 +1,8 @@ +Generation is done with OpenAI in the envs, which is already robust to images in the prompt: OK for that part + +We just need to check in the trainer, I think + +Should the input include an optional images field that will be passed as input to the processing class? + + +Do we need to handle images in _prepare_inputs? Or only in _get_per_token_logps? \ No newline at end of file diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 2e79d2117..36d5a9438 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,10 +4,11 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import TYPE_CHECKING, Literal - +from typing import TYPE_CHECKING, Literal, Union +from transformers import ProcessorMixin, AutoProcessor, AutoConfig from datasets import Dataset from openai import AsyncOpenAI, OpenAI +from transformers.tokenization_utils_base import PreTrainedTokenizerBase from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric @@ -665,7 +666,7 @@ def process_chat_format_vllm( prompt: list[ChatMessage], completion: list[ChatMessage], state: State, - processing_class: "PreTrainedTokenizerBase", + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], mask_env_responses: bool = False, ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: """ @@ -772,7 +773,22 @@ def process_completion_format_vllm( idx = response_start_idx + len(response_text) assert idx == len(completion), "Completion not fully consumed" - prompt_ids: list[int] = processing_class.encode(prompt) + + # Ici on ajoute les images en kwargs + + kwargs = {} + has_images = "image" in inputs[0] + if has_images: + images = [example.get("image") for example in inputs] + kwargs = {"images": [[img] for img in images]} + if isinstance(prompt, list): + for message in prompt: + if isinstance(message, dict) and message.get("role") == "user": + if isinstance(message.get("content"), str): + message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] + break + + prompt_ids: list[int] = processing_class(prompt,**kwargs) rollout_consumed = prompt prompt_mask: list[int] = [0] * len(prompt_ids) completion_ids: list[int] = [] diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 792cc16db..928350208 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -17,6 +17,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers import ProcessorMixin, AutoProcessor, AutoConfig from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import seed_worker @@ -210,7
+211,37 @@ def shuffle_tensor_dict( for key, tensor in tensor_dict.items() } +def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]: + """ + Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in + `batch["image_grid_thw"]`, while keeping other entries unchanged. + """ + if "image_grid_thw" not in batch or "pixel_values" not in batch: + return batch + + lengths = batch["image_grid_thw"].prod(dim=1).tolist() # [batch_size] + pixel_values = batch["pixel_values"] # [total, feature_dim] + + if sum(lengths) != pixel_values.size(0): + raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") + + split_values = list(torch.split(batch["pixel_values"], lengths, dim=0)) + return {**batch, "pixel_values": split_values} + +def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]: + """ + Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` + back into a single tensor along the first dimension. + """ + pixel_values = batch.get("pixel_values") + + if isinstance(pixel_values, list): + merged = torch.cat(pixel_values, dim=0) + return {**batch, "pixel_values": merged} + else: + return batch + def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. @@ -247,7 +278,7 @@ def __init__( model: PreTrainedModel, env: Environment, args: GRPOConfig, - processing_class: PreTrainedTokenizerBase, + processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None, callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[ Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR] @@ -274,9 +305,24 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True - # Tokenizer pad token - if processing_class.pad_token is None: # type: ignore - processing_class.pad_token = processing_class.eos_token # type: ignore + # Handle pad token for processors or tokenizers + if isinstance(processing_class, ProcessorMixin): + tokenizer = processing_class.tokenizer + elif isinstance(processing_class, PreTrainedTokenizerBase): + tokenizer = processing_class + else: + raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + self.pad_token = tokenizer.pad_token + self.pad_token_id = tokenizer.pad_token_id + self.eos_token_id = tokenizer.eos_token_id + self.image_token = getattr(processing_class, "image_token", None) + self.image_token_id = getattr(processing_class, "image_token_id", None) + self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None) + self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None) # Training arguments self.per_device_train_batch_size = args.per_device_train_batch_size @@ -475,9 +521,9 @@ def data_collator(features): elif is_deepspeed_zero3_enabled(): model_id = model.config._name_or_path model_init_kwargs = {"torch_dtype": "auto"} - self.ref_model = AutoModelForCausalLM.from_pretrained( - model_id, **model_init_kwargs - ) + config = AutoConfig.from_pretrained(model_id) + architecture = getattr(transformers, config.architectures[0]) + self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) elif 
is_peft_model(model): # If PEFT is used, the reference model is not needed since the adapter can be disabled # to revert to the initial model. @@ -510,7 +556,8 @@ def data_collator(features): * args.per_device_train_batch_size * args.gradient_accumulation_steps ) - self._textual_logs = { + self._logs = { + "image": deque(maxlen=maxlen), "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), "rewards": defaultdict(lambda: deque(maxlen=maxlen)), @@ -699,24 +746,62 @@ def _inner_training_loop(self, *args, **kwargs): self.async_generator.stop() self._async_started = False + def _get_last_hidden_state( - self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None + self, + unwrapped_model, + input_ids, + attention_mask, + logits_to_keep, + pixel_values=None, + image_grid_thw=None, + pixel_attention_mask=None, + image_sizes=None, ): if is_peft_model(unwrapped_model): unwrapped_model = unwrapped_model.base_model.model - last_hidden_state = unwrapped_model.model( - input_ids=input_ids, attention_mask=attention_mask - ).last_hidden_state + + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + + # For Qwen models: + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw + # For Gemma, SmolVLM2, LLaVa-Next etc.: + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + # For SmolVLM2 + if pixel_attention_mask is not None: + model_inputs["pixel_attention_mask"] = pixel_attention_mask + # For LLaVa-Next + if image_sizes is not None: + model_inputs["image_sizes"] = image_sizes + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state + # Exclude the last value: it corresponds to the next token pred last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) - if logits_to_keep is not None: - last_hidden_state = last_hidden_state[ - :, -logits_to_keep:, : - ] # (B, logits_to_keep, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. 
+ last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) return last_hidden_state - + # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps( - self, model, input_ids, attention_mask, logits_to_keep, batch_size=None + self, + model, + input_ids, + attention_mask, + logits_to_keep, + batch_size=None, + compute_entropy=False, + pixel_values=None, + image_grid_thw=None, ) -> torch.Tensor: batch_size = batch_size or input_ids.size( 0 @@ -725,18 +810,27 @@ def _get_per_token_logps( for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] - logits = model( - input_ids=input_ids_batch, - attention_mask=attention_mask_batch, - logits_to_keep=logits_to_keep + 1, - ).logits - logits = logits[ - :, :-1, : - ] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred - input_ids_batch = input_ids_batch[:, -logits_to_keep:] - # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. - # See https://github.com/huggingface/trl/issues/2770 - logits = logits[:, -logits_to_keep:] + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) + model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + + # Only add logits_to_keep if the model supports it + if "logits_to_keep" in self.model_kwarg_keys: + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + model_inputs["logits_to_keep"] = logits_to_keep + 1 + + logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred + logits = logits[:, :-1, :] # (B, L-1, H) + # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. + logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) # Divide logits by sampling temperature. 
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details logits = logits / self.temperature @@ -994,7 +1088,9 @@ def _prepare_inputs( # type: ignore ) break batch_offset = batch_id - batch_id_to_retrieve - all_prompts, all_answers, all_tasks, all_infos = ( + + # En gros ici il faut recup les images je pense si elles existe + all_prompts, all_answers, all_tasks, all_infos, all_images = ( self._gather_batch_data(batch_offset) ) if not all_prompts: @@ -1002,17 +1098,22 @@ def _prepare_inputs( # type: ignore f"No prompts for batch {batch_id}, stopping batch generation" ) break - + + env_inputs = { + "prompt": all_prompts, + "answer": all_answers, + "task": all_tasks, + "info": all_infos, + } + + if len(all_images) > 0 : + env_inputs["image"] = all_images + # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( batch_id=batch_id, - env_inputs={ - "prompt": all_prompts, - "answer": all_answers, - "task": all_tasks, - "info": all_infos, - }, + env_inputs=env_inputs, processing_class=self.processing_class, mask_env_responses=self.mask_env_responses, max_seq_len=self.max_seq_len or -1, @@ -1190,17 +1291,20 @@ def compute_loss( # type: ignore completion_mask = attention_mask[:, 1:] logits_to_keep = completion_mask.size(1) per_token_logps = self._get_per_token_logps( - model, input_ids, attention_mask, logits_to_keep + model, + input_ids, + attention_mask, + logits_to_keep, + compute_entropy=self.top_entropy_quantile < 1.0, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), ) # Compute the loss advantages = inputs["advantages"] # When using num_iterations == 1, old_per_token_logps == per_token_logps, # so we can skip it's computation (see _generate_and_score_completions) and use per_token_logps.detach() instead. 
- old_per_token_logps = ( - per_token_logps.detach() - if inputs["old_per_token_logps"] is None - else inputs["old_per_token_logps"] - ) + old_per_token_logps = inputs.get("old_per_token_logps") + old_per_token_logps = per_token_logps.detach() if old_per_token_logps is None else old_per_token_logps coef_1 = torch.exp(per_token_logps - old_per_token_logps) coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) @@ -1221,12 +1325,12 @@ def compute_loss( # type: ignore with torch.no_grad(): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep + self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep + self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) @@ -1308,7 +1412,7 @@ def _sanitize_tool_calls( msg.pop("tool_call_id") return completion - def evaluate( + def evaluate( # TODO : check for images self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval", **kwargs ): """ @@ -1458,11 +1562,12 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._metrics[mode].clear() if self.accelerator.is_main_process and self.log_completions: - if len(self._textual_logs["prompt"]) > 0: + if len(self._logs["prompt"]) > 0: print_prompt_completions_sample( - self._textual_logs["prompt"], - self._textual_logs["completion"], - self._textual_logs["rewards"]["reward"], + self._logs["prompt"], + self._logs["completion"], + self._logs["rewards"], + self._logs["advantages"], self.state.global_step, ) @@ -1475,14 +1580,24 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: table = { "step": [str(self.state.global_step)] - * len(self._textual_logs["prompt"]), - "prompt": list(self._textual_logs["prompt"]), + * len(self._logs["prompt"]), + "prompt": list(self._logs["prompt"]), "completion": [ self._sanitize_tool_calls(c) - for c in self._textual_logs["completion"] + for c in self._logs["completion"] ], - **{k: list(v) for k, v in self._textual_logs["rewards"].items()}, + **{k: list(v) for k, v in self._logs["rewards"].items()}, } + + if self._logs["image"]: + table["image"] = [] + for img in self._logs["image"]: + if img is not None: + # Convert images to wandb Image objects for proper visualization + table["image"].append(wandb.Image(img)) + else: + table["image"].append(None) + if len(table["prompt"]) > 0: df = pd.DataFrame(table) if self.wandb_log_unique_prompts: @@ -1490,10 +1605,12 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: wandb.log({"completions": wandb.Table(dataframe=df)}) # Clear the textual logs after logging - self._textual_logs["prompt"].clear() - self._textual_logs["completion"].clear() - for key in self._textual_logs["rewards"]: - self._textual_logs["rewards"][key].clear() + self._logs["prompt"].clear() + self._logs["completion"].clear() + if self._logs["image"] : + self._logs["image"].clear() + for key in self._logs["rewards"]: + self._logs["rewards"][key].clear() def _log_reward_metrics_primary( self, @@ -1535,13 +1652,13 @@ def _log_textual_data_primary( 
Log textual data for wandb (PRIMARY PROCESS ONLY). This logs the full batch of prompts, completions, and rewards. """ - self._textual_logs["prompt"].extend(all_prompts) - self._textual_logs["completion"].extend(all_completions) + self._logs["prompt"].extend(all_prompts) + self._logs["completion"].extend(all_completions) # Log all reward scores - both individual functions and consolidated for reward_key in all_reward_dict: reward_values = all_reward_dict[reward_key] - self._textual_logs["rewards"][reward_key].extend( + self._logs["rewards"][reward_key].extend( reward_values.tolist() if isinstance(reward_values, torch.Tensor) else reward_values From 7ad272a0a5f7bb152c2e0a593c0b6ee7f1df5c3d Mon Sep 17 00:00:00 2001 From: ulrick Date: Sun, 21 Sep 2025 12:24:29 +0000 Subject: [PATCH 02/58] fix some issues with trainer --- verifiers/trainers/grpo_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 928350208..a4dc623ea 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -398,6 +398,8 @@ def __init__( eval_dataset = env.get_eval_dataset() + print("data", train_dataset) + print(train_dataset.column_names) if "prompt" not in train_dataset.column_names: raise ValueError("Train dataset must contain a 'prompt' column") if "answer" not in train_dataset.column_names: From eb31d6e905e22f94b957fa61af82946889034db5 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 21 Sep 2025 15:03:16 +0200 Subject: [PATCH 03/58] fix filter by prompt length to deal with images --- verifiers/trainers/grpo_trainer.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 928350208..fc5ba06d1 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -439,18 +439,31 @@ def __init__( ) max_length = self.max_prompt_length # Capture for closure - def filter_by_prompt_length(example, processing_class): + def filter_by_prompt_length(example, processing_class, max_length): prompt = example["prompt"] - # Tokenize prompt to check length + if isinstance(prompt, list): - # Chat format prompt_text = processing_class.apply_chat_template( prompt, tokenize=False, add_generation_prompt=True ) else: - # Completion format prompt_text = prompt - prompt_ids = processing_class.encode(prompt_text) # type: ignore + + if isinstance(processing_class, PreTrainedTokenizerBase): + prompt_ids = processing_class.encode(prompt_text) + elif isinstance(processing_class, ProcessorMixin): + kwargs = {} + if "image" in example: + kwargs["images"] = [example["image"]] + inputs = processing_class( + text=prompt_text, + return_tensors="pt", + add_special_tokens=False, + **kwargs, + ) + prompt_ids = inputs["input_ids"][0].tolist() + else: + raise ValueError(f"Unsupported processing class: {type(processing_class)}") return len(prompt_ids) <= max_length original_size = len(train_dataset) @@ -1090,7 +1103,7 @@ def _prepare_inputs( # type: ignore batch_offset = batch_id - batch_id_to_retrieve # En gros ici il faut recup les images je pense si elles existe - all_prompts, all_answers, all_tasks, all_infos, all_images = ( + all_prompts, all_answers, all_tasks, all_infos = ( self._gather_batch_data(batch_offset) ) if not all_prompts: @@ -1106,8 +1119,8 @@ def _prepare_inputs( # type: ignore "info": all_infos, } - if len(all_images) > 0 : - env_inputs["image"] = all_images + # if len(all_images) > 0 : 
+ # env_inputs["image"] = all_images # Submit batch (main process only) if self.accelerator.is_main_process: From 83078740b4dbf7c2882058c6bf1571661404f930 Mon Sep 17 00:00:00 2001 From: ulrick Date: Sun, 21 Sep 2025 13:48:10 +0000 Subject: [PATCH 04/58] WIP : fix the dataset filtering with images, working. Arrived at the vllm processing --- verifiers/trainers/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 62b6faaf7..cbe85bcfd 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -441,9 +441,8 @@ def __init__( ) max_length = self.max_prompt_length # Capture for closure - def filter_by_prompt_length(example, processing_class, max_length): + def filter_by_prompt_length(example, processing_class): prompt = example["prompt"] - if isinstance(prompt, list): prompt_text = processing_class.apply_chat_template( prompt, tokenize=False, add_generation_prompt=True @@ -457,6 +456,7 @@ def filter_by_prompt_length(example, processing_class, max_length): kwargs = {} if "image" in example: kwargs["images"] = [example["image"]] + inputs = processing_class( text=prompt_text, return_tensors="pt", From 1a5a66232da0a507c56919addfd45b61e64efeb3 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 21 Sep 2025 18:54:01 +0200 Subject: [PATCH 05/58] add batch treatment with base64 for openAI format with vllm --- verifiers/trainers/grpo_trainer.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 62b6faaf7..a101c958d 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -24,6 +24,8 @@ from trl.models import create_reference_model, prepare_deepspeed from trl.trainer.callbacks import SyncRefModelCallback from trl.trainer.utils import disable_dropout_in_model, pad, selective_log_softmax +import base64 +from io import BytesIO from verifiers import Environment from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest @@ -271,6 +273,14 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) return torch.max(tensor[~torch.isnan(tensor)]) +def pil_to_base64_url(pil_image) -> str: + """ + Convert a PIL image to a base64 URL string suitable for OpenAI/vLLM messages. 
+ """ + buffered = BytesIO() + pil_image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return f"data:image/png;base64,{img_str}" class GRPOTrainer(Trainer): def __init__( @@ -1029,9 +1039,23 @@ def _gather_batch_data( if isinstance(batch, dict): batch = [batch] - + + # Vérifier si le batch contient "image" si c'est le cas, il faut transformer ça en image url # Gather batch data from all processes - prompts = [x["prompt"] for x in batch] + prompts = [] + for x in batch: + prompt = x["prompt"] + if "image" in x and x["image"] is not None: + img_url = pil_to_base64_url(x["image"]) + for message in prompt: + for content in message.get("content", []): + if content.get("type") == "image": + content["type"] = "image_url" + content["image_url"] = img_url + if "image" in content: + del content["image"] + prompts.append(prompt) + answers = [x["answer"] for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] From 652a7c32d69117f4d26cb848d5e009032c65b20b Mon Sep 17 00:00:00 2001 From: ulrick Date: Sun, 21 Sep 2025 18:51:45 +0000 Subject: [PATCH 06/58] WIP async generation working, issue in processing for now --- verifiers/envs/environment.py | 5 ++- verifiers/trainers/grpo_trainer.py | 48 +++++++++++------------- verifiers/utils/message_utils.py | 60 ++++++++++++++++++------------ 3 files changed, 61 insertions(+), 52 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 36d5a9438..f1dcc2904 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -9,7 +9,7 @@ from datasets import Dataset from openai import AsyncOpenAI, OpenAI from transformers.tokenization_utils_base import PreTrainedTokenizerBase - +import transformers from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric from verifiers.types import ( @@ -225,6 +225,7 @@ async def get_model_response( ): sampling_args.pop("max_completion_tokens") clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None} + try: if message_type == "chat": assert isinstance(prompt, list) @@ -471,6 +472,8 @@ def generate( ) -> GenerateOutputs: if isinstance(client, OpenAI): client = AsyncOpenAI(api_key=client.api_key, base_url=client.base_url) + + print("inputs",inputs) coro = self.a_generate( inputs, client, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 0830a28bd..c29a35d1b 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -26,6 +26,7 @@ from trl.trainer.utils import disable_dropout_in_model, pad, selective_log_softmax import base64 from io import BytesIO +import transformers from verifiers import Environment from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest @@ -1014,55 +1015,48 @@ def _ids_to_tensors( ids = torch.stack(ids, dim=0) mask = torch.stack(mask, dim=0) return {"ids": ids, "mask": mask} - - def _gather_batch_data( - self, batch_offset: int = 0 - ) -> Tuple[List[Any], List[Any], List[Any], List[Any]]: + + def _gather_batch_data(self, batch_offset: int = 0): """ - Gather batch data from all processes. - - Args: - batch_offset: 0 for current batch, >0 for future batches (peek ahead) - - Returns: - Tuple of (all_prompts, all_answers, all_tasks) + Gather batch data from all processes and convert PIL images in prompts to base64 image_url. 
""" + print("GATHERING BATCH DATA") batches = self._async_dataloader.peek_ahead(batch_offset) - + if batch_offset == 0: batch = batches[0] if batches else None else: batch = batches[batch_offset - 1] if batches else None - + if batch is None: return [], [], [], [] - + if isinstance(batch, dict): batch = [batch] - - # Vérifier si le batch contient "image" si c'est le cas, il faut transformer ça en image url - # Gather batch data from all processes + prompts = [] for x in batch: prompt = x["prompt"] - if "image" in x and x["image"] is not None: - img_url = pil_to_base64_url(x["image"]) - for message in prompt: - for content in message.get("content", []): - if content.get("type") == "image": - content["type"] = "image_url" - content["image_url"] = img_url - if "image" in content: - del content["image"] + for message in prompt: + for content in message.get("content", []): + if content.get("type") == "image" : + img_url = pil_to_base64_url(x["image"]) + content.clear() + content.update({ + "type": "image_url", + "image_url": {"url":img_url} + }) prompts.append(prompt) - + answers = [x["answer"] for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] + all_prompts = gather_object(prompts) all_answers = gather_object(answers) all_tasks = gather_object(tasks) all_infos = gather_object(infos) + return all_prompts, all_answers, all_tasks, all_infos def _prepare_inputs( # type: ignore diff --git a/verifiers/utils/message_utils.py b/verifiers/utils/message_utils.py index afa730f69..a6ee286a4 100644 --- a/verifiers/utils/message_utils.py +++ b/verifiers/utils/message_utils.py @@ -44,40 +44,52 @@ def messages_to_printable(messages: Messages) -> Messages: def cleanup_message(message: ChatMessage) -> ChatMessage: - new_message = {} - new_message["role"] = message["role"] + new_message = { + "role": message["role"], + "content": [] + } + if "tool_calls" in message: new_message["tool_calls"] = message["tool_calls"] - if "tool_call_id" in message: new_message["tool_call_id"] = message["tool_call_id"] - new_message["content"] = [] content = message.get("content") if content is None: return cast(ChatMessage, new_message) + if isinstance(content, str): new_message["content"] = content - else: - for c in content: - new_c = c.copy() - c_dict = dict(c) - if "image_url" in c_dict and "type" in c_dict and c_dict["type"] == "text": - new_c.pop("image_url") - new_message["content"].append(new_c) - elif ( - "image_url" in c_dict - and "type" in c_dict - and c_dict["type"] == "image_url" - ): - new_c.pop("text") - new_message["content"].append(new_c) - elif str(c_dict.get("type", "")).startswith("input_audio"): - # Ensure input_audio content blocks only have the required fields - clean_c = {"type": "input_audio", "input_audio": c_dict.get("input_audio", {})} - new_message["content"].append(clean_c) - else: - new_message["content"].append(new_c) + return cast(ChatMessage, new_message) + + for c in content: + c_dict = dict(c) + c_type = c_dict.get("type") + + if c_type == "text": + if c_dict.get("text") is not None: + new_message["content"].append({ + "type": "text", + "text": c_dict["text"] + }) + + elif c_type == "image_url": + if "image_url" in c_dict: + new_message["content"].append({ + "type": "image_url", + "image_url": c_dict["image_url"] + }) + + elif c_type == "input_audio": + if "input_audio" in c_dict: + new_message["content"].append({ + "type": "input_audio", + "input_audio": c_dict["input_audio"] + }) + + else: + new_message["content"].append(c_dict) + 
return cast(ChatMessage, new_message) From 1c4ce784c985b5341e8444b04fcd40e69bc5374a Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 21 Sep 2025 20:55:21 +0200 Subject: [PATCH 07/58] fix ruff issues --- verifiers/envs/environment.py | 22 ++++++++-------- verifiers/trainers/grpo_trainer.py | 42 +++++++++++++++--------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index f1dcc2904..0c26a3743 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from typing import TYPE_CHECKING, Literal, Union -from transformers import ProcessorMixin, AutoProcessor, AutoConfig +from transformers import ProcessorMixin from datasets import Dataset from openai import AsyncOpenAI, OpenAI from transformers.tokenization_utils_base import PreTrainedTokenizerBase @@ -780,16 +780,16 @@ def process_completion_format_vllm( # Ici on ajoute les images en kwargs kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - if isinstance(prompt, list): - for message in prompt: - if isinstance(message, dict) and message.get("role") == "user": - if isinstance(message.get("content"), str): - message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] - break + # has_images = "image" in inputs[0] + # if has_images: + # images = [example.get("image") for example in inputs] + # kwargs = {"images": [[img] for img in images]} + # if isinstance(prompt, list): + # for message in prompt: + # if isinstance(message, dict) and message.get("role") == "user": + # if isinstance(message.get("content"), str): + # message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] + # break prompt_ids: list[int] = processing_class(prompt,**kwargs) rollout_consumed = prompt diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index c29a35d1b..5c7f79fd1 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -13,11 +13,10 @@ from accelerate.utils import broadcast_object_list, gather_object, is_peft_model from peft import PeftConfig, get_peft_model from torch.utils.data import DataLoader, Sampler -from transformers import AutoModelForCausalLM from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_utils import PreTrainedModel from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from transformers import ProcessorMixin, AutoProcessor, AutoConfig +from transformers import ProcessorMixin, AutoConfig from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import seed_worker @@ -790,25 +789,26 @@ def _get_last_hidden_state( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - # For Qwen models: - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw - # For Gemma, SmolVLM2, LLaVa-Next etc.: - if pixel_values is not None: - model_inputs["pixel_values"] = pixel_values - # For SmolVLM2 - if pixel_attention_mask is not None: - model_inputs["pixel_attention_mask"] = pixel_attention_mask - # For LLaVa-Next - if image_sizes is not 
None: - model_inputs["image_sizes"] = image_sizes - - # Only add logits_to_keep if the model supports it - if "logits_to_keep" in self.model_kwarg_keys: - # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - model_inputs["logits_to_keep"] = logits_to_keep + 1 - - model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings + # TODO : check if needed or no + # # For Qwen models: + # if image_grid_thw is not None and pixel_values is not None: + # model_inputs["image_grid_thw"] = image_grid_thw + # # For Gemma, SmolVLM2, LLaVa-Next etc.: + # if pixel_values is not None: + # model_inputs["pixel_values"] = pixel_values + # # For SmolVLM2 + # if pixel_attention_mask is not None: + # model_inputs["pixel_attention_mask"] = pixel_attention_mask + # # For LLaVa-Next + # if image_sizes is not None: + # model_inputs["image_sizes"] = image_sizes + + # # Only add logits_to_keep if the model supports it + # if "logits_to_keep" in self.model_kwarg_keys: + # # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + # model_inputs["logits_to_keep"] = logits_to_keep + 1 + + # model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state # Exclude the last value: it corresponds to the next token pred From 121913f3301c85545669272831d0a876dea98412 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 21 Sep 2025 20:57:57 +0200 Subject: [PATCH 08/58] fix ruff issues --- verifiers/envs/environment.py | 1 - verifiers/trainers/grpo_trainer.py | 20 +++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 0c26a3743..97fa99557 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -9,7 +9,6 @@ from datasets import Dataset from openai import AsyncOpenAI, OpenAI from transformers.tokenization_utils_base import PreTrainedTokenizerBase -import transformers from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric from verifiers.types import ( diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 5c7f79fd1..7d8e77c42 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -4,7 +4,7 @@ import time from collections import defaultdict, deque from contextlib import nullcontext -from typing import Any, Dict, List, Optional, Sized, Tuple, Union +from typing import Any, Dict, List, Optional, Sized, Union import datasets import numpy as np @@ -838,14 +838,16 @@ def _get_per_token_logps( attention_mask_batch = attention_mask[i : i + batch_size] # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - elif pixel_values is not None: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + + + # TODO : check if needed there or if already robust to VLM + # if image_grid_thw is not None and pixel_values is not None: + # 
model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] + # start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + # end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + # model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + # elif pixel_values is not None: + # model_inputs["pixel_values"] = pixel_values[start : start + batch_size] # Only add logits_to_keep if the model supports it if "logits_to_keep" in self.model_kwarg_keys: From 7460db2a4a2ae895c50a28ba4e985100390495e8 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 21 Sep 2025 21:24:35 +0200 Subject: [PATCH 09/58] WIP : deal with after rollout : processing the input and output and putting TODO where I think I need to pass vision data --- verifiers/envs/environment.py | 81 ++++++++++++++++++++++++++---- verifiers/trainers/grpo_trainer.py | 2 + 2 files changed, 72 insertions(+), 11 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 97fa99557..fb14a470f 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -29,12 +29,67 @@ ) from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls +import base64 +from io import BytesIO +from typing import List, Dict, Union +from PIL import Image +from transformers import PreTrainedTokenizerBase, ProcessorMixin + if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore PreTrainedTokenizerBase, ) +def _base64_to_pil(data_uri: str) -> Image.Image: + """Convert a base64 data URI (data:image/...;base64,...) to a PIL Image.""" + if not data_uri.startswith("data:image"): + raise ValueError(f"Expected base64 image data URI, got: {data_uri[:30]}") + header, b64data = data_uri.split(",", 1) + image_data = base64.b64decode(b64data) + return Image.open(BytesIO(image_data)).convert("RGB") + + +def encode_chat_with_processor( + conversation: List[Dict], + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + add_generation_prompt: bool = False, + add_special_tokens: bool = False, +) -> List[int]: + """ + Apply chat template and return token IDs, handling both tokenizer and processor. + Supports base64-encoded images in the conversation. + """ + if isinstance(processing_class, PreTrainedTokenizerBase): + return processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + ) + + elif isinstance(processing_class, ProcessorMixin): + prompt_text = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + + images = [] + for msg in conversation: + for c in msg.get("content", []): + if c.get("type") == "image_url": + pil_img = _base64_to_pil(c["image_url"]) + images.append(pil_img) + + inputs = processing_class( + text=[prompt_text], + images=images if images else None, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + return inputs["input_ids"][0].tolist() + else: + raise TypeError(f"Unsupported processing_class: {type(processing_class)}") + class Environment(ABC): """ Base class for all environments. 
@@ -685,8 +740,9 @@ def process_chat_format_vllm( zipped.append((turn, None)) assert len(responses) == responses_idx, "Responses not fully consumed" assert len(zipped) == len(completion), "Length mismatch" - prompt_ids: list[int] = processing_class.apply_chat_template( - conversation=prompt, # type: ignore + prompt_ids: list[int] = encode_chat_with_processor( + conversation=prompt, + processing_class=processing_class, add_generation_prompt=True, ) messages_consumed = [m for m in prompt] @@ -717,13 +773,15 @@ def process_chat_format_vllm( while j < len(zipped) and zipped[j][0]["role"] != "assistant": consecutive_messages.append(zipped[j][0]) j += 1 - token_prefix: list[int] = processing_class.apply_chat_template( - conversation=messages_consumed # type: ignore + token_prefix: list[int] = encode_chat_with_processor( + conversation=messages_consumed, # type: ignore + processing_class=processing_class, + add_generation_prompt=False, ) - token_prefix_with_turn: list[int] = ( - processing_class.apply_chat_template( - conversation=messages_consumed + consecutive_messages, # type: ignore - ) + token_prefix_with_turn: list[int] = encode_chat_with_processor( + conversation=messages_consumed + consecutive_messages, # type: ignore + processing_class=processing_class, + add_generation_prompt=False, ) assert token_prefix_with_turn[: len(token_prefix)] == token_prefix, ( f"Token prefix mismatch. Token prefix: {token_prefix}, token prefix with turn: {token_prefix_with_turn}" @@ -739,6 +797,7 @@ def process_chat_format_vllm( completion_logprobs.extend(completion_turn_logprobs) messages_consumed.extend(consecutive_messages) i = j + # TODO : do I need to return the image grid and so on ? return ( prompt_ids, prompt_mask, @@ -752,7 +811,7 @@ def process_completion_format_vllm( prompt: str, completion: str, state: State, - processing_class: "PreTrainedTokenizerBase", + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], mask_env_responses: bool = False, ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: """ @@ -775,7 +834,7 @@ def process_completion_format_vllm( idx = response_start_idx + len(response_text) assert idx == len(completion), "Completion not fully consumed" - + # TODO : check if needed # Ici on ajoute les images en kwargs kwargs = {} @@ -843,7 +902,7 @@ def process_env_results_vllm( completions: list[Messages], states: list[State], rewards: list[float], - processing_class: "PreTrainedTokenizerBase", + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], max_seq_len: int = -1, mask_env_responses: bool = False, mask_truncated_completions: bool = False, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 7d8e77c42..b6ab64fbc 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1175,6 +1175,8 @@ def _prepare_inputs( # type: ignore # Now retrieve the batch we need for this step if self.accelerator.is_main_process: # Get batch result + + # TODO : toute le get vllm et préparation sert à arriver ici, est-ce que les grid et pixel values sont nécessaires ? 
je pense que oui pour la backward batch_result = self.async_generator.get_batch(batch_id_to_retrieve) processed_results = batch_result.processed_results From 5219e08082080a93c0ef0c89277d4fdafe4a90c4 Mon Sep 17 00:00:00 2001 From: BLE Date: Mon, 22 Sep 2025 22:49:53 +0200 Subject: [PATCH 10/58] WIP : pass pixel values and image grid all way long to prepare input, need to check how it goes with text only and need to test it with Qwen 2.5 VL --- verifiers/envs/environment.py | 53 +++++++++++++++++------------- verifiers/trainers/grpo_trainer.py | 25 +++++++++++--- verifiers/types.py | 2 ++ 3 files changed, 54 insertions(+), 26 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index fb14a470f..381a818e2 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -60,10 +60,11 @@ def encode_chat_with_processor( Supports base64-encoded images in the conversation. """ if isinstance(processing_class, PreTrainedTokenizerBase): - return processing_class.apply_chat_template( + prompt_ids : List[int] = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, ) + return prompt_ids,None,None elif isinstance(processing_class, ProcessorMixin): prompt_text = processing_class.apply_chat_template( @@ -85,7 +86,7 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) - return inputs["input_ids"][0].tolist() + return inputs["input_ids"][0].tolist(), inputs["image_grid"][0].tolist(), inputs["pixel_values"][0].tolist() else: raise TypeError(f"Unsupported processing_class: {type(processing_class)}") @@ -740,11 +741,13 @@ def process_chat_format_vllm( zipped.append((turn, None)) assert len(responses) == responses_idx, "Responses not fully consumed" assert len(zipped) == len(completion), "Length mismatch" - prompt_ids: list[int] = encode_chat_with_processor( + + prompt_ids, prompt_image_grid, prompt_pixel_value = encode_chat_with_processor( conversation=prompt, processing_class=processing_class, add_generation_prompt=True, ) + messages_consumed = [m for m in prompt] prompt_mask: list[int] = [0] * len(prompt_ids) completion_ids: list[int] = [] @@ -773,12 +776,12 @@ def process_chat_format_vllm( while j < len(zipped) and zipped[j][0]["role"] != "assistant": consecutive_messages.append(zipped[j][0]) j += 1 - token_prefix: list[int] = encode_chat_with_processor( + token_prefix, token_prefix_image_grid, token_prefix_pixel_values = encode_chat_with_processor( conversation=messages_consumed, # type: ignore processing_class=processing_class, add_generation_prompt=False, ) - token_prefix_with_turn: list[int] = encode_chat_with_processor( + token_prefix_with_turn, token_prefix_with_turn_image_grid,token_prefix_with_turn_pixel_values = encode_chat_with_processor( conversation=messages_consumed + consecutive_messages, # type: ignore processing_class=processing_class, add_generation_prompt=False, @@ -791,16 +794,19 @@ def process_chat_format_vllm( completion_turn_mask = [0] * len(completion_turn_ids) else: completion_turn_mask = [1] * len(completion_turn_ids) + + # TODO : what about images in turn ? for now we consider we hav'nt completion_turn_logprobs = [0.0] * len(completion_turn_ids) completion_ids.extend(completion_turn_ids) completion_mask.extend(completion_turn_mask) completion_logprobs.extend(completion_turn_logprobs) messages_consumed.extend(consecutive_messages) i = j - # TODO : do I need to return the image grid and so on ? 
return ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -834,27 +840,17 @@ def process_completion_format_vllm( idx = response_start_idx + len(response_text) assert idx == len(completion), "Completion not fully consumed" - # TODO : check if needed - # Ici on ajoute les images en kwargs - - kwargs = {} - # has_images = "image" in inputs[0] - # if has_images: - # images = [example.get("image") for example in inputs] - # kwargs = {"images": [[img] for img in images]} - # if isinstance(prompt, list): - # for message in prompt: - # if isinstance(message, dict) and message.get("role") == "user": - # if isinstance(message.get("content"), str): - # message["content"] = [{"type": "image"}, {"type": "text", "text": message["content"]}] - # break - - prompt_ids: list[int] = processing_class(prompt,**kwargs) + prompt_ids, prompt_image_grid, prompt_pixel_value = encode_chat_with_processor( + conversation=prompt, + processing_class=processing_class, + add_generation_prompt=False, + ) rollout_consumed = prompt prompt_mask: list[int] = [0] * len(prompt_ids) completion_ids: list[int] = [] completion_mask: list[int] = [] completion_logprobs: list[float] = [] + i = 0 while i < len(zipped): text, response = zipped[i] @@ -891,6 +887,8 @@ def process_completion_format_vllm( return ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -915,6 +913,8 @@ def process_env_results_vllm( all_prompt_ids = [] all_prompt_masks = [] + all_prompt_image_grid = [] + all_prompt_pixel_value = [] all_completion_ids = [] all_completion_masks = [] all_completion_logprobs = [] @@ -928,6 +928,8 @@ def process_env_results_vllm( ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -939,6 +941,8 @@ def process_env_results_vllm( ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, @@ -970,6 +974,8 @@ def process_env_results_vllm( ) all_prompt_ids.append(prompt_ids) all_prompt_masks.append(prompt_mask) + all_prompt_image_grid.append(prompt_image_grid) + all_prompt_pixel_value.append(prompt_pixel_value) all_completion_ids.append(completion_ids) all_completion_masks.append(completion_mask) all_completion_logprobs.append(completion_logprobs) @@ -980,10 +986,13 @@ def process_env_results_vllm( return ProcessedOutputs( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, + image_grid = all_prompt_image_grid, + pixel_value= all_prompt_pixel_value, completion_ids=all_completion_ids, completion_mask=all_completion_masks, completion_logprobs=all_completion_logprobs, rewards=all_rewards, + ) # alias for process_env_results_vllm diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b6ab64fbc..d280933b8 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1124,7 +1124,6 @@ def _prepare_inputs( # type: ignore break batch_offset = batch_id - batch_id_to_retrieve - # En gros ici il faut recup les images je pense si elles existe all_prompts, all_answers, all_tasks, all_infos = ( self._gather_batch_data(batch_offset) ) @@ -1140,9 +1139,6 @@ def _prepare_inputs( # type: ignore "task": all_tasks, "info": all_infos, } - - # if len(all_images) > 0 : - # env_inputs["image"] = all_images # Submit batch (main process only) if self.accelerator.is_main_process: @@ -1190,6 +1186,8 @@ def 
_prepare_inputs( # type: ignore "all_reward_dict": batch_result.all_reward_dict, "completions": batch_result.completions, "prompts": batch_result.prompts, + "pixel_values":processed_results.get("pixel_values"), + "image_grid_thw":processed_results.get("image_grid_thw"), } else: broadcast_data = None @@ -1219,6 +1217,8 @@ def _prepare_inputs( # type: ignore # Now create tensors only for this process's slice input_ids_list = [] attention_mask_list = [] + pixel_values_list = [] + image_grid_list = [] for i in range(process_slice.start, process_slice.stop): input_ids_list.append( @@ -1235,6 +1235,18 @@ def _prepare_inputs( # type: ignore device=self.accelerator.device, ) ) + pixel_values_list.append( + torch.tensor( + broadcast_data["pixel_values"][i], + device=self.accelerator.device, + ) + ) + image_grid_list.append( + torch.tensor( + broadcast_data["image_grid"][i], + device=self.accelerator.device, + ) + ) input_ids = pad( input_ids_list, @@ -1243,6 +1255,9 @@ def _prepare_inputs( # type: ignore ) # type: ignore attention_mask = pad(attention_mask_list, padding_side="right") # type: ignore + pixel_values = torch.stack(pixel_values_list, dim=0) + image_grid = torch.stack(image_grid_list, dim=0) + # Truncate if needed if self.max_seq_len is not None and input_ids.size(1) > self.max_seq_len: input_ids = input_ids[:, -self.max_seq_len :] @@ -1280,6 +1295,8 @@ def _prepare_inputs( # type: ignore "attention_mask": attention_mask, "old_per_token_logps": None, "advantages": advantages, + "pixel_values": pixel_values, + "image_grid": image_grid } # Shuffle and split for gradient accumulation diff --git a/verifiers/types.py b/verifiers/types.py index bd9b20786..6fee00652 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -88,6 +88,8 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] + image_grid: list[list[int]] + pixel_values: list[list[int]] completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] From eac0ccfb9dc44d9bdb53b5125492948836343a2f Mon Sep 17 00:00:00 2001 From: ulrick Date: Tue, 23 Sep 2025 16:30:18 +0000 Subject: [PATCH 11/58] WIP : fix encoding and treatment, issue with pydantic in ProcessedOutputs --- verifiers/envs/environment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 381a818e2..d6f2b1269 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -77,7 +77,7 @@ def encode_chat_with_processor( for msg in conversation: for c in msg.get("content", []): if c.get("type") == "image_url": - pil_img = _base64_to_pil(c["image_url"]) + pil_img = _base64_to_pil(c["image_url"]["url"]) images.append(pil_img) inputs = processing_class( @@ -86,7 +86,7 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) - return inputs["input_ids"][0].tolist(), inputs["image_grid"][0].tolist(), inputs["pixel_values"][0].tolist() + return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"][0].tolist() else: raise TypeError(f"Unsupported processing_class: {type(processing_class)}") From e000737f091e5ab552ef91d8a6946fb9d1afa6ff Mon Sep 17 00:00:00 2001 From: ulrick Date: Tue, 23 Sep 2025 18:13:00 +0000 Subject: [PATCH 12/58] fix until compute loss issue --- verifiers/envs/environment.py | 4 ++-- verifiers/trainers/grpo_trainer.py | 24 ++++++++++++------------ verifiers/types.py | 4 ++-- 3 files changed, 16 
insertions(+), 16 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index d6f2b1269..7efc86ccc 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -986,8 +986,8 @@ def process_env_results_vllm( return ProcessedOutputs( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, - image_grid = all_prompt_image_grid, - pixel_value= all_prompt_pixel_value, + image_grid_thw = all_prompt_image_grid, + pixel_values= all_prompt_pixel_value, completion_ids=all_completion_ids, completion_mask=all_completion_masks, completion_logprobs=all_completion_logprobs, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index d280933b8..ffb2354b5 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -236,7 +236,7 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` back into a single tensor along the first dimension. """ - pixel_values = batch.get("pixel_values") + pixel_values = batch.pixel_values if isinstance(pixel_values, list): merged = torch.cat(pixel_values, dim=0) @@ -1186,8 +1186,8 @@ def _prepare_inputs( # type: ignore "all_reward_dict": batch_result.all_reward_dict, "completions": batch_result.completions, "prompts": batch_result.prompts, - "pixel_values":processed_results.get("pixel_values"), - "image_grid_thw":processed_results.get("image_grid_thw"), + "pixel_values":processed_results.pixel_values, + "image_grid_thw":processed_results.image_grid_thw, } else: broadcast_data = None @@ -1243,20 +1243,20 @@ def _prepare_inputs( # type: ignore ) image_grid_list.append( torch.tensor( - broadcast_data["image_grid"][i], + broadcast_data["image_grid_thw"][i], device=self.accelerator.device, ) ) input_ids = pad( input_ids_list, - padding_value=self.processing_class.pad_token_id, # type: ignore + padding_value=self.pad_token_id, # type: ignore padding_side="right", ) # type: ignore attention_mask = pad(attention_mask_list, padding_side="right") # type: ignore pixel_values = torch.stack(pixel_values_list, dim=0) - image_grid = torch.stack(image_grid_list, dim=0) + image_grid_thw = torch.stack(image_grid_list, dim=0) # Truncate if needed if self.max_seq_len is not None and input_ids.size(1) > self.max_seq_len: @@ -1296,7 +1296,7 @@ def _prepare_inputs( # type: ignore "old_per_token_logps": None, "advantages": advantages, "pixel_values": pixel_values, - "image_grid": image_grid + "image_grid_thw": image_grid_thw } # Shuffle and split for gradient accumulation @@ -1350,8 +1350,8 @@ def compute_loss( # type: ignore attention_mask, logits_to_keep, compute_entropy=self.top_entropy_quantile < 1.0, - pixel_values=inputs.get("pixel_values"), - image_grid_thw=inputs.get("image_grid_thw"), + pixel_values=inputs.pixel_values, + image_grid_thw=inputs.image_grid_thw, ) # Compute the loss advantages = inputs["advantages"] @@ -1379,12 +1379,12 @@ def compute_loss( # type: ignore with torch.no_grad(): if self.ref_model is not None: ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), + self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.pixel_values,image_grid_thw=inputs.image_grid_thw, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = 
self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), + self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.pixel_values,image_grid_thw=inputs.image_grid_thw, ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) @@ -1754,7 +1754,7 @@ def _log_completion_metrics_primary( term_lengths = [] for comp_ids, comp_mask in zip(all_completion_ids, all_completion_mask): has_eos = any( - token == self.processing_class.eos_token_id # type: ignore + token == self.eos_token_id # type: ignore for token, mask in zip(comp_ids, comp_mask) if mask ) diff --git a/verifiers/types.py b/verifiers/types.py index 6fee00652..6b0f883cf 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -88,8 +88,8 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] - image_grid: list[list[int]] - pixel_values: list[list[int]] + image_grid_thw: list[list[int]] + pixel_values: list[list[float]] completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] From f31ddb116d9974b200515557865325de1da0fc3d Mon Sep 17 00:00:00 2001 From: BLE Date: Tue, 23 Sep 2025 20:13:50 +0200 Subject: [PATCH 13/58] fix pixel values --- verifiers/envs/environment.py | 2 +- verifiers/types.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index d6f2b1269..0285cbbdd 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -987,7 +987,7 @@ def process_env_results_vllm( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, image_grid = all_prompt_image_grid, - pixel_value= all_prompt_pixel_value, + pixel_values= all_prompt_pixel_value, completion_ids=all_completion_ids, completion_mask=all_completion_masks, completion_logprobs=all_completion_logprobs, diff --git a/verifiers/types.py b/verifiers/types.py index 6fee00652..b9169e36b 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -89,7 +89,7 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] image_grid: list[list[int]] - pixel_values: list[list[int]] + pixel_values: list[list[float]] completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] From 5e3ac99c17434d7942cd786641d8c368b55636cb Mon Sep 17 00:00:00 2001 From: BLE Date: Tue, 23 Sep 2025 20:19:30 +0200 Subject: [PATCH 14/58] WIP : try to fix compute loss --- verifiers/trainers/grpo_trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index ffb2354b5..51f0dec6c 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -825,7 +825,6 @@ def _get_per_token_logps( attention_mask, logits_to_keep, batch_size=None, - compute_entropy=False, pixel_values=None, image_grid_thw=None, ) -> torch.Tensor: @@ -1349,9 +1348,8 @@ def compute_loss( # type: ignore input_ids, attention_mask, logits_to_keep, - compute_entropy=self.top_entropy_quantile < 1.0, - pixel_values=inputs.pixel_values, - image_grid_thw=inputs.image_grid_thw, + pixel_values=inputs.get("pixel_values"), + image_grid_thw=inputs.get("image_grid_thw"), ) # Compute the loss advantages = inputs["advantages"] @@ -1379,12 +1377,12 @@ def compute_loss( # type: ignore with torch.no_grad(): if self.ref_model is not None: 
ref_per_token_logps = self._get_per_token_logps( - self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.pixel_values,image_grid_thw=inputs.image_grid_thw, + self.ref_model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): # type: ignore ref_per_token_logps = self._get_per_token_logps( - self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.pixel_values,image_grid_thw=inputs.image_grid_thw, + self.model, input_ids, attention_mask, logits_to_keep, pixel_values=inputs.get("pixel_values"),image_grid_thw=inputs.get("image_grid_thw"), ) per_token_kl = ( torch.exp(ref_per_token_logps - per_token_logps) From b44d3f908da79d542d387ad60dbdda3e007185d1 Mon Sep 17 00:00:00 2001 From: ulrick Date: Tue, 23 Sep 2025 18:33:46 +0000 Subject: [PATCH 15/58] WIP : working until self._get_per_token_logps model(**model_inputs).logits --- verifiers/trainers/grpo_trainer.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 51f0dec6c..7478c04ea 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -5,6 +5,7 @@ from collections import defaultdict, deque from contextlib import nullcontext from typing import Any, Dict, List, Optional, Sized, Union +import inspect import datasets import numpy as np @@ -298,6 +299,12 @@ def __init__( ): self.logger = logging.getLogger(__name__) + self.model_kwarg_keys = ( + inspect.signature(model.forward).parameters.keys() + if not hasattr(model, "get_base_model") + else inspect.signature(model.get_base_model().forward).parameters.keys() + ) + # Models if peft_config is not None: model = get_peft_model(model, peft_config) # type: ignore @@ -314,7 +321,7 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True - + # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): tokenizer = processing_class.tokenizer @@ -835,18 +842,19 @@ def _get_per_token_logps( for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] + # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} # TODO : check if needed there or if already robust to VLM - # if image_grid_thw is not None and pixel_values is not None: - # model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - # start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - # end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() - # model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - # elif pixel_values is not None: - # model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + if image_grid_thw is not None and pixel_values is not None: + model_inputs["image_grid_thw"] = image_grid_thw[i : i + batch_size] + start_pixel_idx = image_grid_thw[:i].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: i + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[i : i + batch_size] # Only 
add logits_to_keep if the model supports it if "logits_to_keep" in self.model_kwarg_keys: From b4ddd581adeef831ca222d3cdf9170eb22597039 Mon Sep 17 00:00:00 2001 From: ulrick Date: Tue, 23 Sep 2025 18:51:36 +0000 Subject: [PATCH 16/58] FIX Batch for text data while keeping string for image, however, issue with pixel values after --- verifiers/trainers/grpo_trainer.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 7478c04ea..0c66dd739 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1047,14 +1047,21 @@ def _gather_batch_data(self, batch_offset: int = 0): for x in batch: prompt = x["prompt"] for message in prompt: - for content in message.get("content", []): - if content.get("type") == "image" : - img_url = pil_to_base64_url(x["image"]) - content.clear() - content.update({ - "type": "image_url", - "image_url": {"url":img_url} - }) + content = message.get("content", []) + if isinstance(content, list): + for c in content: + if isinstance(c, dict) and c.get("type") == "image": + img_url = pil_to_base64_url(x["image"]) + c.clear() + c.update({ + "type": "image_url", + "image_url": {"url": img_url} + }) + elif isinstance(content, str): + pass + else: + print("Unknown content type:", type(content)) + prompts.append(prompt) answers = [x["answer"] for x in batch] From dd4ed6db4eb60638a6f1185aea471e75c914abdc Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 24 Sep 2025 21:27:52 +0200 Subject: [PATCH 17/58] fix issue with shape --- verifiers/envs/environment.py | 2 +- verifiers/trainers/grpo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 7efc86ccc..c2fcd742c 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -86,7 +86,7 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) - return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"][0].tolist() + return inputs["input_ids"], inputs["image_grid_thw"], inputs["pixel_values"] else: raise TypeError(f"Unsupported processing_class: {type(processing_class)}") diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 51f0dec6c..d99b0a04e 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1234,7 +1234,7 @@ def _prepare_inputs( # type: ignore device=self.accelerator.device, ) ) - pixel_values_list.append( + pixel_values_list.append( # TODO : ici je pense qu'il faut gérer si on stack ensuite sur la première dimension ou pas torch.tensor( broadcast_data["pixel_values"][i], device=self.accelerator.device, From 58a39cebd050b2022ab89267135cb29f2dadce90 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 24 Sep 2025 21:32:14 +0200 Subject: [PATCH 18/58] check pixel values issue --- verifiers/envs/environment.py | 3 +++ verifiers/trainers/grpo_trainer.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index c2fcd742c..1898bc14e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -86,6 +86,7 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) + print("encode_chat_with_processor",inputs["pixel_values"].shape) return inputs["input_ids"], inputs["image_grid_thw"], 
inputs["pixel_values"] else: @@ -983,6 +984,8 @@ def process_env_results_vllm( all_rewards.append(0) else: all_rewards.append(reward) + + print("process_env_results_vllm",all_prompt_pixel_value.shape) return ProcessedOutputs( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 8dff45bf1..9f197da39 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -846,7 +846,7 @@ def _get_per_token_logps( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - + print("_get_per_token_logps",pixel_values.shape) # TODO : check if needed there or if already robust to VLM if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[i : i + batch_size] From ba0302dac953ab514e981578c402f13ba1c15cf9 Mon Sep 17 00:00:00 2001 From: ulrick Date: Thu, 25 Sep 2025 20:13:57 +0000 Subject: [PATCH 19/58] WIP : VL training working end to end, but rebus too complex -> need to improve prompt so some pass + need to fix for text only --- verifiers/envs/environment.py | 5 +--- verifiers/trainers/grpo_trainer.py | 35 ++++++++++++++++------- verifiers/types.py | 5 ++-- verifiers/utils/logging_utils.py | 46 +++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 1898bc14e..ee9cd3176 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -86,8 +86,7 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) - print("encode_chat_with_processor",inputs["pixel_values"].shape) - return inputs["input_ids"], inputs["image_grid_thw"], inputs["pixel_values"] + return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() else: raise TypeError(f"Unsupported processing_class: {type(processing_class)}") @@ -529,7 +528,6 @@ def generate( if isinstance(client, OpenAI): client = AsyncOpenAI(api_key=client.api_key, base_url=client.base_url) - print("inputs",inputs) coro = self.a_generate( inputs, client, @@ -985,7 +983,6 @@ def process_env_results_vllm( else: all_rewards.append(reward) - print("process_env_results_vllm",all_prompt_pixel_value.shape) return ProcessedOutputs( prompt_ids=all_prompt_ids, prompt_mask=all_prompt_masks, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 9f197da39..eb5711ef3 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -32,7 +32,7 @@ from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper from verifiers.trainers.grpo_config import GRPOConfig -from verifiers.utils.logging_utils import print_prompt_completions_sample +from verifiers.utils.logging_utils import print_prompt_completions_sample, serialize_for_wandb class RepeatSampler(Sampler): @@ -415,8 +415,6 @@ def __init__( eval_dataset = env.get_eval_dataset() - print("data", train_dataset) - print(train_dataset.column_names) if "prompt" not in train_dataset.column_names: raise ValueError("Train dataset must contain a 'prompt' column") if "answer" not in train_dataset.column_names: @@ -846,8 +844,6 @@ def _get_per_token_logps( # Build model inputs - check 
if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - print("_get_per_token_logps",pixel_values.shape) - # TODO : check if needed there or if already robust to VLM if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[i : i + batch_size] start_pixel_idx = image_grid_thw[:i].prod(-1).sum().item() @@ -862,8 +858,10 @@ def _get_per_token_logps( model_inputs["logits_to_keep"] = logits_to_keep + 1 logits = model(**model_inputs).logits + # Exclude the last value: it corresponds to the next token pred logits = logits[:, :-1, :] # (B, L-1, H) + input_ids_batch = input_ids_batch[:, -logits_to_keep:] # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H) # Divide logits by sampling temperature. @@ -1029,7 +1027,6 @@ def _gather_batch_data(self, batch_offset: int = 0): """ Gather batch data from all processes and convert PIL images in prompts to base64 image_url. """ - print("GATHERING BATCH DATA") batches = self._async_dataloader.peek_ahead(batch_offset) if batch_offset == 0: @@ -1152,8 +1149,7 @@ def _prepare_inputs( # type: ignore "answer": all_answers, "task": all_tasks, "info": all_infos, - } - + } # Submit batch (main process only) if self.accelerator.is_main_process: request = BatchRequest( @@ -1186,7 +1182,6 @@ def _prepare_inputs( # type: ignore if self.accelerator.is_main_process: # Get batch result - # TODO : toute le get vllm et préparation sert à arriver ici, est-ce que les grid et pixel values sont nécessaires ? je pense que oui pour la backward batch_result = self.async_generator.get_batch(batch_id_to_retrieve) processed_results = batch_result.processed_results @@ -1633,8 +1628,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: print_prompt_completions_sample( self._logs["prompt"], self._logs["completion"], - self._logs["rewards"], - self._logs["advantages"], + self._logs["rewards"]["reward"], self.state.global_step, ) @@ -1666,9 +1660,14 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: table["image"].append(None) if len(table["prompt"]) > 0: + table["prompt"] = [serialize_for_wandb(p) for p in table["prompt"]] + table["completion"] = [serialize_for_wandb(c) for c in table["completion"]] + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) # Clear the textual logs after logging @@ -1731,6 +1730,20 @@ def _log_textual_data_primary( else reward_values ) + def _log_image_data_primary(self, all_images: List[Any]) -> None: + """ + Log images for wandb (PRIMARY PROCESS ONLY). + Converts each image to wandb.Image and stores it in the _logs deque. 
+ """ + if "image" not in self._logs: + self._logs["image"] = deque(maxlen=self._logs_maxlen) + + for img in all_images: + if img is not None: + self._logs["image"].append(wandb.Image(img)) + else: + self._logs["image"].append(None) + def _log_completion_metrics_primary( self, mode: str, diff --git a/verifiers/types.py b/verifiers/types.py index 6b0f883cf..c3bcb3dc4 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -5,6 +5,7 @@ Awaitable, Callable, Literal, + Optional ) from openai.types.chat.chat_completion import ChatCompletion @@ -88,8 +89,8 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] - image_grid_thw: list[list[int]] - pixel_values: list[list[float]] + image_grid_thw: Optional[list[list[int]]] = None + pixel_values: Optional[list[list[list[float]]]] = None completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index fba11437a..267e30144 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -1,6 +1,7 @@ import json import logging import sys +import copy from rich.console import Console from rich.panel import Panel @@ -10,6 +11,47 @@ from verifiers.types import Messages from collections.abc import Mapping +def sanitize_and_serialize(obj): + """ + Sanitize Base64 images and convert nested dict/list to string for WandB. + """ + if isinstance(obj, dict): + obj = {k: sanitize_and_serialize(v) for k, v in obj.items()} + if "image_url" in obj and isinstance(obj["image_url"], dict): + url = obj["image_url"].get("url") + if isinstance(url, str) and url.startswith("data:image/"): + obj["image_url"]["url"] = "" + return obj + elif isinstance(obj, list): + return [sanitize_and_serialize(x) for x in obj] + else: + return obj + +def serialize_for_wandb(obj): + sanitized = sanitize_and_serialize(obj) + return json.dumps(sanitized, ensure_ascii=False) + + +def sanitize_message_for_logging(msg): + """ + Recursively sanitize a message dict, removing Base64 data from image URLs. 
+ """ + msg = copy.deepcopy(msg) + + if isinstance(msg, dict): + for k, v in msg.items(): + if k == "image_url" and isinstance(v, dict) and "url" in v: + url = v["url"] + if url.startswith("data:image/"): + v["url"] = "" + else: + msg[k] = sanitize_message_for_logging(v) + + elif isinstance(msg, list): + msg = [sanitize_message_for_logging(x) for x in msg] + + return msg + def setup_logging( level: str = "INFO", @@ -86,7 +128,9 @@ def _format_messages(messages) -> Text: style = "bright_cyan" if role == "assistant" else "bright_magenta" out.append(f"{role}: ", style="bold") - out.append(content, style=style) + + safe_content = sanitize_message_for_logging(content) + out.append(str(safe_content), style=style) for tc in msg.get("tool_calls") or []: # treat None as empty list payload = _normalize_tool_call(tc) From 3004416b23fa5122594152be57a7876875a16b7e Mon Sep 17 00:00:00 2001 From: ulrick Date: Sun, 28 Sep 2025 09:26:41 +0000 Subject: [PATCH 20/58] update typing --- verifiers/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/types.py b/verifiers/types.py index c3bcb3dc4..0a32b6458 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -89,8 +89,8 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] - image_grid_thw: Optional[list[list[int]]] = None - pixel_values: Optional[list[list[list[float]]]] = None + image_grid_thw: Optional[list[Optional[list[int]]]] = None + pixel_values: Optional[list[Optional[list[list[float]]]]] = None completion_ids: list[list[int]] completion_mask: list[list[int]] completion_logprobs: list[list[float]] From 5209a042c9e2fbeb388830b8105bd57688260af3 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 28 Sep 2025 20:25:35 +0200 Subject: [PATCH 21/58] add image logging and answer logging in wandb to improve data diging --- verifiers/envs/environment.py | 4 ++- verifiers/trainers/async_batch_generator.py | 6 ++-- verifiers/trainers/grpo_trainer.py | 22 +++++++++----- verifiers/utils/logging_utils.py | 32 +++++++++++++++++++++ 4 files changed, 54 insertions(+), 10 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ee9cd3176..91f764d08 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -918,6 +918,8 @@ def process_env_results_vllm( all_completion_masks = [] all_completion_logprobs = [] all_rewards = [] + all_images=[] + all_answers=[] for i, (prompt, completion, state, reward) in enumerate( zip(prompts, completions, states, rewards) ): @@ -978,6 +980,7 @@ def process_env_results_vllm( all_completion_ids.append(completion_ids) all_completion_masks.append(completion_mask) all_completion_logprobs.append(completion_logprobs) + if zero_truncated_completions and is_truncated: all_rewards.append(0) else: @@ -992,7 +995,6 @@ def process_env_results_vllm( completion_mask=all_completion_masks, completion_logprobs=all_completion_logprobs, rewards=all_rewards, - ) # alias for process_env_results_vllm diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index b7c44a96d..279a697b3 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -4,7 +4,7 @@ import threading import time from collections import deque -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, Field @@ -38,6 +38,7 @@ class BatchResult(BaseModel): default_factory=list ) # Store completions for logging prompts: list[Any] = 
Field(default_factory=list) # Store prompts for logging + answers : list[Any] class AsyncBatchGenerator: @@ -292,7 +293,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: mask_env_responses=request.mask_env_responses, mask_truncated_completions=request.mask_truncated_completions, zero_truncated_completions=request.zero_truncated_completions, - ) + ) return BatchResult( batch_id=request.batch_id, @@ -300,6 +301,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: all_reward_dict=all_reward_dict, completions=env_results.completion, prompts=env_results.prompt, + answers=request.env_inputs.answer ) async def _evaluate_async(self, num_samples: int = -1) -> GenerateOutputs: diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index eb5711ef3..94c3bfdd0 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -32,7 +32,7 @@ from verifiers.trainers.async_batch_generator import AsyncBatchGenerator, BatchRequest from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper from verifiers.trainers.grpo_config import GRPOConfig -from verifiers.utils.logging_utils import print_prompt_completions_sample, serialize_for_wandb +from verifiers.utils.logging_utils import print_prompt_completions_sample, serialize_for_wandb, extract_images class RepeatSampler(Sampler): @@ -1197,6 +1197,7 @@ def _prepare_inputs( # type: ignore "prompts": batch_result.prompts, "pixel_values":processed_results.pixel_values, "image_grid_thw":processed_results.image_grid_thw, + "answer": batch_result.answers, } else: broadcast_data = None @@ -1629,6 +1630,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._logs["prompt"], self._logs["completion"], self._logs["rewards"]["reward"], + self._logs["answer"], self.state.global_step, ) @@ -1647,6 +1649,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._sanitize_tool_calls(c) for c in self._logs["completion"] ], + "answer" : list(self._logs["answers"]) **{k: list(v) for k, v in self._logs["rewards"].items()}, } @@ -1654,27 +1657,30 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: table["image"] = [] for img in self._logs["image"]: if img is not None: - # Convert images to wandb Image objects for proper visualization table["image"].append(wandb.Image(img)) else: table["image"].append(None) if len(table["prompt"]) > 0: + all_images = [extract_images(p) for p in table["prompt"]] # list of lists + wandb_images = [[wandb.Image(img) for img in imgs] for imgs in all_images] + + if any(len(imgs) > 0 for imgs in wandb_images): + table["images"] = wandb_images + table["prompt"] = [serialize_for_wandb(p) for p in table["prompt"]] table["completion"] = [serialize_for_wandb(c) for c in table["completion"]] - + df = pd.DataFrame(table) - + if self.wandb_log_unique_prompts: df = df.drop_duplicates(subset=["prompt"]) - + wandb.log({"completions": wandb.Table(dataframe=df)}) # Clear the textual logs after logging self._logs["prompt"].clear() self._logs["completion"].clear() - if self._logs["image"] : - self._logs["image"].clear() for key in self._logs["rewards"]: self._logs["rewards"][key].clear() @@ -1713,6 +1719,7 @@ def _log_textual_data_primary( all_prompts: List[Union[str, List[Dict[str, Any]]]], all_completions: List[Union[str, List[Dict[str, Any]]]], all_reward_dict: Dict[str, Any], + all_answers : List[Any] ) -> None: """ Log textual data for 
wandb (PRIMARY PROCESS ONLY). @@ -1720,6 +1727,7 @@ def _log_textual_data_primary( """ self._logs["prompt"].extend(all_prompts) self._logs["completion"].extend(all_completions) + self._logs["answers"].extend(all_answers) # Log all reward scores - both individual functions and consolidated for reward_key in all_reward_dict: diff --git a/verifiers/utils/logging_utils.py b/verifiers/utils/logging_utils.py index 267e30144..03998b323 100644 --- a/verifiers/utils/logging_utils.py +++ b/verifiers/utils/logging_utils.py @@ -11,6 +11,38 @@ from verifiers.types import Messages from collections.abc import Mapping +import base64 +from io import BytesIO +from PIL import Image + + +def extract_images(obj): + """ + Extract and decode Base64 images into a list of PIL.Image objects. + """ + images = [] + + def _extract(o): + if isinstance(o, dict): + for v in o.values(): + _extract(v) + if "image_url" in o and isinstance(o["image_url"], dict): + url = o["image_url"].get("url") + if isinstance(url, str) and url.startswith("data:image/"): + try: + header, b64_data = url.split(",", 1) + image_data = base64.b64decode(b64_data) + image = Image.open(BytesIO(image_data)) + images.append(image) + except Exception: + pass + elif isinstance(o, list): + for v in o: + _extract(v) + + _extract(obj) + return images + def sanitize_and_serialize(obj): """ Sanitize Base64 images and convert nested dict/list to string for WandB. From 47c1df0c0cac3b9a9257bf3dbf7a1a395cd92204 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 28 Sep 2025 20:35:51 +0200 Subject: [PATCH 22/58] make code robust to text only --- verifiers/trainers/grpo_trainer.py | 39 ++++++++++++++++++------------ 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 94c3bfdd0..7e786c067 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1229,6 +1229,11 @@ def _prepare_inputs( # type: ignore attention_mask_list = [] pixel_values_list = [] image_grid_list = [] + + has_images = any( + broadcast_data["pixel_values"][i] is not None + for i in range(process_slice.start, process_slice.stop) + ) for i in range(process_slice.start, process_slice.stop): input_ids_list.append( @@ -1245,18 +1250,18 @@ def _prepare_inputs( # type: ignore device=self.accelerator.device, ) ) - pixel_values_list.append( # TODO : ici je pense qu'il faut gérer si on stack ensuite sur la première dimension ou pas - torch.tensor( - broadcast_data["pixel_values"][i], - device=self.accelerator.device, - ) - ) - image_grid_list.append( - torch.tensor( - broadcast_data["image_grid_thw"][i], - device=self.accelerator.device, - ) - ) + if has_images: + if broadcast_data["pixel_values"][i] is not None: + pixel_values_list.append( + torch.tensor(broadcast_data["pixel_values"][i], device=self.accelerator.device) + ) + image_grid_list.append( + torch.tensor(broadcast_data["image_grid_thw"][i], device=self.accelerator.device) + ) + else: + # If some examples have no image insert dummy with correct shape + pixel_values_list.append(torch.zeros_like(pixel_values_list[0])) + image_grid_list.append(torch.zeros_like(image_grid_list[0])) input_ids = pad( input_ids_list, @@ -1265,8 +1270,9 @@ def _prepare_inputs( # type: ignore ) # type: ignore attention_mask = pad(attention_mask_list, padding_side="right") # type: ignore - pixel_values = torch.stack(pixel_values_list, dim=0) - image_grid_thw = torch.stack(image_grid_list, dim=0) + if has_images: + pixel_values = torch.stack(pixel_values_list, 
dim=0) + image_grid_thw = torch.stack(image_grid_list, dim=0) # Truncate if needed if self.max_seq_len is not None and input_ids.size(1) > self.max_seq_len: @@ -1305,9 +1311,10 @@ def _prepare_inputs( # type: ignore "attention_mask": attention_mask, "old_per_token_logps": None, "advantages": advantages, - "pixel_values": pixel_values, - "image_grid_thw": image_grid_thw } + if has_images: + full_batch["pixel_values"] = pixel_values + full_batch["image_grid_thw"] = image_grid_thw # Shuffle and split for gradient accumulation full_batch = shuffle_tensor_dict(full_batch) From 1b515a8e9062ff715b41890d8044c5748b5dbe52 Mon Sep 17 00:00:00 2001 From: ulrick Date: Mon, 29 Sep 2025 18:54:48 +0000 Subject: [PATCH 23/58] logging of image and answer working + full training works, maybe issue on the reward calculation of the env for japaneese --- verifiers/envs/environment.py | 1 - verifiers/trainers/async_batch_generator.py | 2 +- verifiers/trainers/grpo_trainer.py | 10 ++++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 91f764d08..a2d13760e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -125,7 +125,6 @@ def __init__( self.logger.warning( "The parser and rubric parser are different. This may cause unexpected behavior." ) - if self.message_type == "chat": if dataset is not None: self.dataset = self.format_dataset( diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 279a697b3..4cbe123f9 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -301,7 +301,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: all_reward_dict=all_reward_dict, completions=env_results.completion, prompts=env_results.prompt, - answers=request.env_inputs.answer + answers=request.env_inputs.get("answer") ) async def _evaluate_async(self, num_samples: int = -1) -> GenerateOutputs: diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 7e786c067..b89188545 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -414,7 +414,7 @@ def __init__( assert train_dataset is not None eval_dataset = env.get_eval_dataset() - + if "prompt" not in train_dataset.column_names: raise ValueError("Train dataset must contain a 'prompt' column") if "answer" not in train_dataset.column_names: @@ -591,6 +591,7 @@ def data_collator(features): "prompt": deque(maxlen=maxlen), "completion": deque(maxlen=maxlen), "rewards": defaultdict(lambda: deque(maxlen=maxlen)), + "answers": deque(maxlen=maxlen), } # OpenAI client for Environment generation (using vLLM server) @@ -1197,7 +1198,7 @@ def _prepare_inputs( # type: ignore "prompts": batch_result.prompts, "pixel_values":processed_results.pixel_values, "image_grid_thw":processed_results.image_grid_thw, - "answer": batch_result.answers, + "answers": batch_result.answers, } else: broadcast_data = None @@ -1295,6 +1296,7 @@ def _prepare_inputs( # type: ignore all_prompts=broadcast_data["prompts"], all_completions=broadcast_data["completions"], all_reward_dict=broadcast_data["all_reward_dict"], + all_answers=broadcast_data["answers"], ) # Log completion metrics using full batch data on CPU to save memory @@ -1637,7 +1639,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._logs["prompt"], self._logs["completion"], self._logs["rewards"]["reward"], - 
self._logs["answer"], + self._logs["answers"], self.state.global_step, ) @@ -1656,7 +1658,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self._sanitize_tool_calls(c) for c in self._logs["completion"] ], - "answer" : list(self._logs["answers"]) + "answer" : list(self._logs["answers"]), **{k: list(v) for k, v in self._logs["rewards"].items()}, } From 277ed435688859fced274f91a4c3cf33d258398d Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:06:12 +0200 Subject: [PATCH 24/58] fix ruff --- notes.md | 8 -------- verifiers/envs/environment.py | 13 +++++-------- 2 files changed, 5 insertions(+), 16 deletions(-) delete mode 100644 notes.md diff --git a/notes.md b/notes.md deleted file mode 100644 index 8db212e2c..000000000 --- a/notes.md +++ /dev/null @@ -1,8 +0,0 @@ -La génération se fait avec OpenAI dans les envs qui est déjà robuste aux images dans prompt : ok pour ça - -Il faut juste check dans le trainer je pense - -Il faut que dans l'input il y ait images en optionnel qui sera en input de la processing classe ? - - -Est-ce que dans _prepare_inputs il faut gérer les images ? ou seulement dans _get_per_token_logps ? \ No newline at end of file diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index a2d13760e..0fabb5681 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,11 +4,14 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import TYPE_CHECKING, Literal, Union -from transformers import ProcessorMixin +from typing import TYPE_CHECKING, Literal, Union, List, Dict from datasets import Dataset from openai import AsyncOpenAI, OpenAI from transformers.tokenization_utils_base import PreTrainedTokenizerBase +import base64 +from io import BytesIO +from PIL import Image +from transformers import PreTrainedTokenizerBase, ProcessorMixin from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric from verifiers.types import ( @@ -29,12 +32,6 @@ ) from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls -import base64 -from io import BytesIO -from typing import List, Dict, Union -from PIL import Image -from transformers import PreTrainedTokenizerBase, ProcessorMixin - if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore PreTrainedTokenizerBase, From c6c0d677ed9350f2cd637700663406cbfdaa6061 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:06:53 +0200 Subject: [PATCH 25/58] fix ruff --- verifiers/envs/environment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 0fabb5681..b2e5c35ab 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Literal, Union, List, Dict from datasets import Dataset from openai import AsyncOpenAI, OpenAI -from transformers.tokenization_utils_base import PreTrainedTokenizerBase import base64 from io import BytesIO from PIL import Image From 6652dacde8edd29104ea6c4124777543c03c9521 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:07:43 +0200 Subject: [PATCH 26/58] fix ruff --- verifiers/envs/environment.py | 1 - 1 file changed, 1 deletion(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index b2e5c35ab..79016fc2a 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -789,7 +789,6 @@ def 
process_chat_format_vllm( else: completion_turn_mask = [1] * len(completion_turn_ids) - # TODO : what about images in turn ? for now we consider we hav'nt completion_turn_logprobs = [0.0] * len(completion_turn_ids) completion_ids.extend(completion_turn_ids) completion_mask.extend(completion_turn_mask) From 6dc9d405d8d49f433e11ab1f1a0979f73bfc01bf Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:08:20 +0200 Subject: [PATCH 27/58] fix ruff --- verifiers/envs/environment.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 79016fc2a..0a2161236 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -912,8 +912,6 @@ def process_env_results_vllm( all_completion_masks = [] all_completion_logprobs = [] all_rewards = [] - all_images=[] - all_answers=[] for i, (prompt, completion, state, reward) in enumerate( zip(prompts, completions, states, rewards) ): From 289dc9fd7e17d5c067a2c9cf69e624b21acde326 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:08:49 +0200 Subject: [PATCH 28/58] fix ruff --- verifiers/trainers/async_batch_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 4cbe123f9..43144f46a 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -4,7 +4,7 @@ import threading import time from collections import deque -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field From d1f5f0785f2a739d4b3b16bbaa142c1ed69afb4b Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:09:50 +0200 Subject: [PATCH 29/58] fix ruff --- verifiers/trainers/async_batch_generator.py | 2 +- verifiers/trainers/grpo_trainer.py | 31 --------------------- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 43144f46a..d56bd179d 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -293,7 +293,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: mask_env_responses=request.mask_env_responses, mask_truncated_completions=request.mask_truncated_completions, zero_truncated_completions=request.zero_truncated_completions, - ) + ) return BatchResult( batch_id=request.batch_id, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b89188545..c33bfc71c 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -213,37 +213,6 @@ def shuffle_tensor_dict( key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items() } - -def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]: - """ - Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in - `batch["image_grid_thw"]`, while keeping other entries unchanged. 
- """ - if "image_grid_thw" not in batch or "pixel_values" not in batch: - return batch - - lengths = batch["image_grid_thw"].prod(dim=1).tolist() # [batch_size] - pixel_values = batch["pixel_values"] # [total, feature_dim] - - if sum(lengths) != pixel_values.size(0): - raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") - - split_values = list(torch.split(batch["pixel_values"], lengths, dim=0)) - return {**batch, "pixel_values": split_values} - - -def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]: - """ - Opposite of `split_pixel_values_by_grid`. Merges a list of tensors in `batch["pixel_values"]` - back into a single tensor along the first dimension. - """ - pixel_values = batch.pixel_values - - if isinstance(pixel_values, list): - merged = torch.cat(pixel_values, dim=0) - return {**batch, "pixel_values": merged} - else: - return batch def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ From 8e1abbebdc12fb6f9fc7c60a28acd4e3002d3186 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:16:23 +0200 Subject: [PATCH 30/58] fix ruff --- pyproject.toml | 1 + verifiers/trainers/grpo_trainer.py | 33 ++++++------------------------ 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a2fa24600..229ebbdfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "rich", "textual", "openai-agents>=0.0.7", + "pillow>=10.0.0" ] [project.optional-dependencies] diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index c33bfc71c..5890fef41 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -383,7 +383,7 @@ def __init__( assert train_dataset is not None eval_dataset = env.get_eval_dataset() - + if "prompt" not in train_dataset.column_names: raise ValueError("Train dataset must contain a 'prompt' column") if "answer" not in train_dataset.column_names: @@ -764,34 +764,13 @@ def _get_last_hidden_state( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} - # TODO : check if needed or no - # # For Qwen models: - # if image_grid_thw is not None and pixel_values is not None: - # model_inputs["image_grid_thw"] = image_grid_thw - # # For Gemma, SmolVLM2, LLaVa-Next etc.: - # if pixel_values is not None: - # model_inputs["pixel_values"] = pixel_values - # # For SmolVLM2 - # if pixel_attention_mask is not None: - # model_inputs["pixel_attention_mask"] = pixel_attention_mask - # # For LLaVa-Next - # if image_sizes is not None: - # model_inputs["image_sizes"] = image_sizes - - # # Only add logits_to_keep if the model supports it - # if "logits_to_keep" in self.model_kwarg_keys: - # # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded - # model_inputs["logits_to_keep"] = logits_to_keep + 1 - - # model_inputs["use_cache"] = False # only used in generation; set False to suppress warnings - last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state # Exclude the last value: it corresponds to the next token pred last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op. 
last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) return last_hidden_state - + # Get the per-token log probabilities for the completions for the model and the reference model def _get_per_token_logps( self, @@ -998,18 +977,18 @@ def _gather_batch_data(self, batch_offset: int = 0): Gather batch data from all processes and convert PIL images in prompts to base64 image_url. """ batches = self._async_dataloader.peek_ahead(batch_offset) - + if batch_offset == 0: batch = batches[0] if batches else None else: batch = batches[batch_offset - 1] if batches else None - + if batch is None: return [], [], [], [] - + if isinstance(batch, dict): batch = [batch] - + prompts = [] for x in batch: prompt = x["prompt"] From 1971169ae18682ae1151b668aecfb443f114a781 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:20:12 +0200 Subject: [PATCH 31/58] fix ruff --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 229ebbdfd..a59cc89ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,8 @@ dependencies = [ "rich", "textual", "openai-agents>=0.0.7", - "pillow>=10.0.0" + "pillow>=10.0.0", + "transformers" ] [project.optional-dependencies] @@ -45,7 +46,6 @@ all = [ "pytest-cov>=4.0.0", "requests", "torch>=2.7.0", - "transformers", "accelerate>=1.4.0", "deepspeed", "peft", From c5433031d9def963614f2799ce88c13b9b01676b Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:32:09 +0200 Subject: [PATCH 32/58] modify encode so it is robust to tests --- verifiers/envs/environment.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 0a2161236..ab50bc10a 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -55,14 +55,8 @@ def encode_chat_with_processor( Apply chat template and return token IDs, handling both tokenizer and processor. Supports base64-encoded images in the conversation. 
""" - if isinstance(processing_class, PreTrainedTokenizerBase): - prompt_ids : List[int] = processing_class.apply_chat_template( - conversation=conversation, - add_generation_prompt=add_generation_prompt, - ) - return prompt_ids,None,None - elif isinstance(processing_class, ProcessorMixin): + if isinstance(processing_class, ProcessorMixin): prompt_text = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, @@ -85,7 +79,11 @@ def encode_chat_with_processor( return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() else: - raise TypeError(f"Unsupported processing_class: {type(processing_class)}") + prompt_ids : List[int] = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + ) + return prompt_ids,None,None class Environment(ABC): """ From c043abf6de3687dccfd36da18561bd595684b98c Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:39:19 +0200 Subject: [PATCH 33/58] change test process chat format to adapt to new output --- tests/test_environment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_environment.py b/tests/test_environment.py index 759dd841c..9a442ca5b 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -261,6 +261,8 @@ def apply_template(conversation, tokenize=False, add_generation_prompt=True): ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, From b1e917fd66eacbc5a96bc4ae93ebe0603b3bf718 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 21:42:15 +0200 Subject: [PATCH 34/58] change test process chat format to adapt to new output --- tests/test_environment.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_environment.py b/tests/test_environment.py index 9a442ca5b..08c79ea32 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -308,6 +308,8 @@ def test_process_completion_format(self, mock_openai_client, sample_dataset): ( prompt_ids, prompt_mask, + prompt_image_grid, + prompt_pixel_value, completion_ids, completion_mask, completion_logprobs, From 55ad4064b076de6f08df639cca27010771f8b10a Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 22:16:42 +0200 Subject: [PATCH 35/58] change test process chat format to adapt to new output --- verifiers/envs/environment.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ab50bc10a..16d5d0aec 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -85,6 +85,31 @@ def encode_chat_with_processor( ) return prompt_ids,None,None +def encode_text_with_processor( + text: str, + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + add_special_tokens: bool = False, +) -> tuple[list[int], Any, Any]: + """ + Encode plain text and return token IDs, handling both tokenizer and processor. 
+ """ + if isinstance(processing_class, ProcessorMixin): + inputs = processing_class( + text=[text], + images=None, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + input_ids = inputs["input_ids"][0].tolist() + image_grid = inputs.get("image_grid_thw", [None])[0].tolist() + pixel_values = inputs.get("pixel_values", [None]).tolist() + return input_ids, image_grid, pixel_values + else: + prompt_ids: list[int] = processing_class.encode( + text, add_special_tokens=add_special_tokens + ) + return prompt_ids, None, None + class Environment(ABC): """ Base class for all environments. @@ -831,10 +856,10 @@ def process_completion_format_vllm( idx = response_start_idx + len(response_text) assert idx == len(completion), "Completion not fully consumed" - prompt_ids, prompt_image_grid, prompt_pixel_value = encode_chat_with_processor( - conversation=prompt, + prompt_ids, prompt_image_grid, prompt_pixel_value = encode_text_with_processor( + text=prompt, # The prompt is a string for completion format processing_class=processing_class, - add_generation_prompt=False, + add_special_tokens=False, ) rollout_consumed = prompt prompt_mask: list[int] = [0] * len(prompt_ids) From 8a37577f0347cd37af10d5973d801bc4ca601627 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 22:18:08 +0200 Subject: [PATCH 36/58] change test process chat format to adapt to new output --- verifiers/envs/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 16d5d0aec..b3fbc7bfb 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import TYPE_CHECKING, Literal, Union, List, Dict +from typing import TYPE_CHECKING, Literal, Union, List, Dict, Any from datasets import Dataset from openai import AsyncOpenAI, OpenAI import base64 From bbebab75203fcf5b31e4b519a171bef77f1a3af6 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 1 Oct 2025 22:21:55 +0200 Subject: [PATCH 37/58] modif encoding to be robust to text only --- verifiers/envs/environment.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index b3fbc7bfb..26d196d0b 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -88,7 +88,6 @@ def encode_chat_with_processor( def encode_text_with_processor( text: str, processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], - add_special_tokens: bool = False, ) -> tuple[list[int], Any, Any]: """ Encode plain text and return token IDs, handling both tokenizer and processor. 
@@ -98,7 +97,6 @@ def encode_text_with_processor( text=[text], images=None, return_tensors="pt", - add_special_tokens=add_special_tokens, ) input_ids = inputs["input_ids"][0].tolist() image_grid = inputs.get("image_grid_thw", [None])[0].tolist() @@ -106,7 +104,7 @@ def encode_text_with_processor( return input_ids, image_grid, pixel_values else: prompt_ids: list[int] = processing_class.encode( - text, add_special_tokens=add_special_tokens + text ) return prompt_ids, None, None @@ -857,9 +855,8 @@ def process_completion_format_vllm( assert idx == len(completion), "Completion not fully consumed" prompt_ids, prompt_image_grid, prompt_pixel_value = encode_text_with_processor( - text=prompt, # The prompt is a string for completion format + text=prompt, processing_class=processing_class, - add_special_tokens=False, ) rollout_consumed = prompt prompt_mask: list[int] = [0] * len(prompt_ids) From ae412a5d263b5e71fc68c066909b7e64e5d902b1 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 21:39:08 +0200 Subject: [PATCH 38/58] remove transformers from base dependencies --- pyproject.toml | 24 ------------------------ verifiers/envs/environment.py | 3 +-- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cf7ca804e..413cc4f73 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,30 +38,6 @@ dependencies = [ "textual", "openai-agents>=0.0.7", "pillow>=10.0.0", - "transformers" -] - -[project.optional-dependencies] -all = [ - "ruff", - "pre-commit", - "pytest>=7.0.0", - "pytest-asyncio>=0.21.0", - "pytest-cov>=4.0.0", - "requests", - "torch>=2.7.0", - "accelerate>=1.4.0", - "deepspeed", - "peft", - "wandb", - "trl>=0.17.0", - "vllm>=0.9.2", - "liger-kernel>=0.5.10", - "nest-asyncio>=1.6.0", - "ipykernel", - "ipywidgets", - "math-verify>=0.8.0", - "pydantic>=2.11.9", ] [dependency-groups] diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e9f378d11..7c9ee8465 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -10,7 +10,6 @@ import base64 from io import BytesIO from PIL import Image -from transformers import PreTrainedTokenizerBase, ProcessorMixin from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric from verifiers.types import ( @@ -33,7 +32,7 @@ if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore - PreTrainedTokenizerBase, + PreTrainedTokenizerBase, ProcessorMixin ) def _base64_to_pil(data_uri: str) -> Image.Image: From d2b6166052c7f37e6fad0b7cbd6f8ac3fd43d834 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:02:41 +0200 Subject: [PATCH 39/58] in grpo trainer, only convert images from PIL if not base64 --- verifiers/trainers/grpo_trainer.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 74449d163..1d2c1939b 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1021,20 +1021,25 @@ def _gather_batch_data(self, batch_offset: int = 0): content = message.get("content", []) if isinstance(content, list): for c in content: - if isinstance(c, dict) and c.get("type") == "image": - img_url = pil_to_base64_url(x["image"]) - c.clear() - c.update({ - "type": "image_url", - "image_url": {"url": img_url} - }) + if isinstance(c, dict): + if c.get("type") == "image": # Convert only if not already base64 + if "image_url" not in c: + if "image" in x: # only convert if PIL image 
exists + img_url = pil_to_base64_url(x["image"]) + c.clear() + c.update({ + "type": "image_url", + "image_url": {"url": img_url} + }) + elif c.get("type") == "image_url": # Already base64, leave as is + pass elif isinstance(content, str): pass else: print("Unknown content type:", type(content)) prompts.append(prompt) - + answers = [x["answer"] for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] From 02caa79e6d91427a14748d7f8a1cd0615144f363 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:21:18 +0200 Subject: [PATCH 40/58] Create image utils for base64 and PIL transformation, move processing of outputs and inputs to model utils with lazy import --- verifiers/envs/environment.py | 76 +--------------------------------- verifiers/utils/model_utils.py | 71 ++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 77 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 7c9ee8465..e9575db2c 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -7,9 +7,6 @@ from typing import TYPE_CHECKING, Literal, Union, List, Dict, Any from datasets import Dataset from openai import AsyncOpenAI, OpenAI -import base64 -from io import BytesIO -from PIL import Image from verifiers.parsers.parser import Parser from verifiers.rubrics.rubric import Rubric from verifiers.types import ( @@ -29,83 +26,12 @@ State, ) from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls +from verifiers.utils.model_utils import encode_text_with_processor, encode_chat_with_processor if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore PreTrainedTokenizerBase, ProcessorMixin ) - -def _base64_to_pil(data_uri: str) -> Image.Image: - """Convert a base64 data URI (data:image/...;base64,...) to a PIL Image.""" - if not data_uri.startswith("data:image"): - raise ValueError(f"Expected base64 image data URI, got: {data_uri[:30]}") - header, b64data = data_uri.split(",", 1) - image_data = base64.b64decode(b64data) - return Image.open(BytesIO(image_data)).convert("RGB") - - -def encode_chat_with_processor( - conversation: List[Dict], - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], - add_generation_prompt: bool = False, - add_special_tokens: bool = False, -) -> List[int]: - """ - Apply chat template and return token IDs, handling both tokenizer and processor. - Supports base64-encoded images in the conversation. 
- """ - - if isinstance(processing_class, ProcessorMixin): - prompt_text = processing_class.apply_chat_template( - conversation=conversation, - add_generation_prompt=add_generation_prompt, - tokenize=False, - ) - - images = [] - for msg in conversation: - for c in msg.get("content", []): - if c.get("type") == "image_url": - pil_img = _base64_to_pil(c["image_url"]["url"]) - images.append(pil_img) - - inputs = processing_class( - text=[prompt_text], - images=images if images else None, - return_tensors="pt", - add_special_tokens=add_special_tokens, - ) - return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() - - else: - prompt_ids : List[int] = processing_class.apply_chat_template( - conversation=conversation, - add_generation_prompt=add_generation_prompt, - ) - return prompt_ids,None,None - -def encode_text_with_processor( - text: str, - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], -) -> tuple[list[int], Any, Any]: - """ - Encode plain text and return token IDs, handling both tokenizer and processor. - """ - if isinstance(processing_class, ProcessorMixin): - inputs = processing_class( - text=[text], - images=None, - return_tensors="pt", - ) - input_ids = inputs["input_ids"][0].tolist() - image_grid = inputs.get("image_grid_thw", [None])[0].tolist() - pixel_values = inputs.get("pixel_values", [None]).tolist() - return input_ids, image_grid, pixel_values - else: - prompt_ids: list[int] = processing_class.encode( - text - ) - return prompt_ids, None, None class Environment(ABC): """ diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 3003dea17..2180378f8 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,5 +1,7 @@ from importlib.util import find_spec -from typing import Any, Callable +from typing import Any, Callable, List, Dict, Union + +from verifiers.utils.image_utils import _base64_to_pil import torch # type: ignore[unresolved-import] import torch.nn as nn # type: ignore[unresolved-import] @@ -7,7 +9,9 @@ AutoModelForCausalLM, AutoTokenizer, ) - +from transformers.tokenization_utils_base import ( # type: ignore[unresolved-import] + PreTrainedTokenizerBase, ProcessorMixin +) class _ForwardRedirection: """Implements the `forward-redirection`. @@ -111,3 +115,66 @@ def get_model_and_tokenizer( model = get_model(model_name, use_liger, model_kwargs) tokenizer = get_tokenizer(model_name) return model, tokenizer + +def encode_chat_with_processor( + conversation: List[Dict], + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + add_generation_prompt: bool = False, + add_special_tokens: bool = False, +) -> List[int]: + """ + Apply chat template and return token IDs, handling both tokenizer and processor. + Supports base64-encoded images in the conversation. 
+ """ + + if isinstance(processing_class, ProcessorMixin): + prompt_text = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + + images = [] + for msg in conversation: + for c in msg.get("content", []): + if c.get("type") == "image_url": + pil_img = _base64_to_pil(c["image_url"]["url"]) + images.append(pil_img) + + inputs = processing_class( + text=[prompt_text], + images=images if images else None, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() + + else: + prompt_ids : List[int] = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + ) + return prompt_ids,None,None + +def encode_text_with_processor( + text: str, + processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], +) -> tuple[list[int], Any, Any]: + """ + Encode plain text and return token IDs, handling both tokenizer and processor. + """ + if isinstance(processing_class, ProcessorMixin): + inputs = processing_class( + text=[text], + images=None, + return_tensors="pt", + ) + input_ids = inputs["input_ids"][0].tolist() + image_grid = inputs.get("image_grid_thw", [None])[0].tolist() + pixel_values = inputs.get("pixel_values", [None]).tolist() + return input_ids, image_grid, pixel_values + else: + prompt_ids: list[int] = processing_class.encode( + text + ) + return prompt_ids, None, None \ No newline at end of file From 030dddca8042e1284cf6c52d53cb599933a159a6 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:28:01 +0200 Subject: [PATCH 41/58] add image utils --- verifiers/utils/image_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 verifiers/utils/image_utils.py diff --git a/verifiers/utils/image_utils.py b/verifiers/utils/image_utils.py new file mode 100644 index 000000000..51fb4ac26 --- /dev/null +++ b/verifiers/utils/image_utils.py @@ -0,0 +1,11 @@ +import base64 +from io import BytesIO +from PIL import Image + +def _base64_to_pil(data_uri: str) -> Image.Image: + """Convert a base64 data URI (data:image/...;base64,...) 
to a PIL Image.""" + if not data_uri.startswith("data:image"): + raise ValueError(f"Expected base64 image data URI, got: {data_uri[:30]}") + header, b64data = data_uri.split(",", 1) + image_data = base64.b64decode(b64data) + return Image.open(BytesIO(image_data)).convert("RGB") \ No newline at end of file From 50eb136c2269df2dc6ffa65c6fda66967a9cb269 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:41:41 +0200 Subject: [PATCH 42/58] add processor utils with lazy imports and no transformers to use in environment --- verifiers/envs/environment.py | 2 +- verifiers/trainers/grpo_trainer.py | 11 +---- verifiers/utils/image_utils.py | 11 ++++- verifiers/utils/model_utils.py | 69 +---------------------------- verifiers/utils/processor_utils.py | 71 ++++++++++++++++++++++++++++++ 5 files changed, 84 insertions(+), 80 deletions(-) create mode 100644 verifiers/utils/processor_utils.py diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e9575db2c..eb2cf9115 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -26,7 +26,7 @@ State, ) from verifiers.utils.message_utils import cleanup_messages, sanitize_tool_calls -from verifiers.utils.model_utils import encode_text_with_processor, encode_chat_with_processor +from verifiers.utils.processor_utils import encode_text_with_processor, encode_chat_with_processor if TYPE_CHECKING: from transformers.tokenization_utils_base import ( # type: ignore diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 1d2c1939b..90f6add14 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -53,7 +53,7 @@ from verifiers.trainers.async_dataloader_wrapper import AsyncDataLoaderWrapper from verifiers.trainers.grpo_config import GRPOConfig from verifiers.utils.logging_utils import print_prompt_completions_sample, serialize_for_wandb, extract_images - +from verifiers.utils.image_utils import pil_to_base64_url class RepeatSampler(Sampler): """ @@ -268,15 +268,6 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor: return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) return torch.max(tensor[~torch.isnan(tensor)]) -def pil_to_base64_url(pil_image) -> str: - """ - Convert a PIL image to a base64 URL string suitable for OpenAI/vLLM messages. - """ - buffered = BytesIO() - pil_image.save(buffered, format="PNG") - img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") - return f"data:image/png;base64,{img_str}" - class GRPOTrainer(Trainer): def __init__( self, diff --git a/verifiers/utils/image_utils.py b/verifiers/utils/image_utils.py index 51fb4ac26..ef061935c 100644 --- a/verifiers/utils/image_utils.py +++ b/verifiers/utils/image_utils.py @@ -8,4 +8,13 @@ def _base64_to_pil(data_uri: str) -> Image.Image: raise ValueError(f"Expected base64 image data URI, got: {data_uri[:30]}") header, b64data = data_uri.split(",", 1) image_data = base64.b64decode(b64data) - return Image.open(BytesIO(image_data)).convert("RGB") \ No newline at end of file + return Image.open(BytesIO(image_data)).convert("RGB") + +def pil_to_base64_url(pil_image) -> str: + """ + Convert a PIL image to a base64 URL string suitable for OpenAI/vLLM messages. 
+ """ + buffered = BytesIO() + pil_image.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return f"data:image/png;base64,{img_str}" \ No newline at end of file diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 2180378f8..27a28c820 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,7 +1,6 @@ from importlib.util import find_spec from typing import Any, Callable, List, Dict, Union -from verifiers.utils.image_utils import _base64_to_pil import torch # type: ignore[unresolved-import] import torch.nn as nn # type: ignore[unresolved-import] @@ -9,9 +8,6 @@ AutoModelForCausalLM, AutoTokenizer, ) -from transformers.tokenization_utils_base import ( # type: ignore[unresolved-import] - PreTrainedTokenizerBase, ProcessorMixin -) class _ForwardRedirection: """Implements the `forward-redirection`. @@ -114,67 +110,4 @@ def get_model_and_tokenizer( ) -> tuple[Any, Any]: model = get_model(model_name, use_liger, model_kwargs) tokenizer = get_tokenizer(model_name) - return model, tokenizer - -def encode_chat_with_processor( - conversation: List[Dict], - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], - add_generation_prompt: bool = False, - add_special_tokens: bool = False, -) -> List[int]: - """ - Apply chat template and return token IDs, handling both tokenizer and processor. - Supports base64-encoded images in the conversation. - """ - - if isinstance(processing_class, ProcessorMixin): - prompt_text = processing_class.apply_chat_template( - conversation=conversation, - add_generation_prompt=add_generation_prompt, - tokenize=False, - ) - - images = [] - for msg in conversation: - for c in msg.get("content", []): - if c.get("type") == "image_url": - pil_img = _base64_to_pil(c["image_url"]["url"]) - images.append(pil_img) - - inputs = processing_class( - text=[prompt_text], - images=images if images else None, - return_tensors="pt", - add_special_tokens=add_special_tokens, - ) - return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() - - else: - prompt_ids : List[int] = processing_class.apply_chat_template( - conversation=conversation, - add_generation_prompt=add_generation_prompt, - ) - return prompt_ids,None,None - -def encode_text_with_processor( - text: str, - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], -) -> tuple[list[int], Any, Any]: - """ - Encode plain text and return token IDs, handling both tokenizer and processor. 
- """ - if isinstance(processing_class, ProcessorMixin): - inputs = processing_class( - text=[text], - images=None, - return_tensors="pt", - ) - input_ids = inputs["input_ids"][0].tolist() - image_grid = inputs.get("image_grid_thw", [None])[0].tolist() - pixel_values = inputs.get("pixel_values", [None]).tolist() - return input_ids, image_grid, pixel_values - else: - prompt_ids: list[int] = processing_class.encode( - text - ) - return prompt_ids, None, None \ No newline at end of file + return model, tokenizer \ No newline at end of file diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py new file mode 100644 index 000000000..3c01592e0 --- /dev/null +++ b/verifiers/utils/processor_utils.py @@ -0,0 +1,71 @@ +from verifiers.utils.image_utils import _base64_to_pil +from typing import Union, List, Dict, Any, Any, TYPE_CHECKING +from inspect import signature + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase, ProcessorMixin + +def encode_chat_with_processor( + conversation: List[Dict], + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], + add_generation_prompt: bool = False, + add_special_tokens: bool = False, +) -> List[int]: + """ + Apply chat template and return token IDs, handling both tokenizer and processor. + Supports base64-encoded images in the conversation. + """ + + sig = signature(processing_class.__call__ if hasattr(processing_class, "__call__") else processing_class) # work as a replacement for if isinstance(processing_class, ProcessorMixin), because we already type check with lazy importfrom inspect import signature + if "images" in sig.parameters: + prompt_text = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + tokenize=False, + ) + + images = [] + for msg in conversation: + for c in msg.get("content", []): + if c.get("type") == "image_url": + pil_img = _base64_to_pil(c["image_url"]["url"]) + images.append(pil_img) + + inputs = processing_class( + text=[prompt_text], + images=images if images else None, + return_tensors="pt", + add_special_tokens=add_special_tokens, + ) + return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() + + else: + prompt_ids : List[int] = processing_class.apply_chat_template( + conversation=conversation, + add_generation_prompt=add_generation_prompt, + ) + return prompt_ids,None,None + +def encode_text_with_processor( + text: str, + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], +) -> tuple[list[int], Any, Any]: + """ + Encode plain text and return token IDs, handling both tokenizer and processor. 
+ """ + sig = signature(processing_class.__call__ if hasattr(processing_class, "__call__") else processing_class) # work as a replacement for if isinstance(processing_class, ProcessorMixin), because we already type check with lazy importfrom inspect import signature + if "images" in sig.parameters: + inputs = processing_class( + text=[text], + images=None, + return_tensors="pt", + ) + input_ids = inputs["input_ids"][0].tolist() + image_grid = inputs.get("image_grid_thw", [None])[0].tolist() + pixel_values = inputs.get("pixel_values", [None]).tolist() + return input_ids, image_grid, pixel_values + else: + prompt_ids: list[int] = processing_class.encode( + text + ) + return prompt_ids, None, None \ No newline at end of file From 653efb64e5490110651e245bdc0fa88f382ec7d9 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:44:50 +0200 Subject: [PATCH 43/58] clean type checking to avoid transformers --- verifiers/envs/environment.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index eb2cf9115..03d9d14e1 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -805,7 +805,7 @@ def process_chat_format_vllm( prompt: list[ChatMessage], completion: list[ChatMessage], state: State, - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: """ @@ -929,7 +929,7 @@ def process_completion_format_vllm( prompt: str, completion: str, state: State, - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: """ @@ -1011,7 +1011,7 @@ def process_env_results_vllm( completions: list[Messages], states: list[State], rewards: list[float], - processing_class: Union[PreTrainedTokenizerBase, ProcessorMixin], + processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], max_seq_len: int = -1, mask_env_responses: bool = False, mask_truncated_completions: bool = False, From 3512a3b54195d4923d605736ec25640a18ddd827 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 22:54:06 +0200 Subject: [PATCH 44/58] fix issue with not callable Processor --- verifiers/utils/processor_utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 3c01592e0..2cf3fa5b4 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -4,7 +4,16 @@ if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase, ProcessorMixin - + +def supports_images(obj): # work as a replacement for if isinstance(processing_class, ProcessorMixin), because we already type check with lazy importfrom inspect import signature + if callable(obj): + try: + sig = signature(obj) + return "images" in sig.parameters + except TypeError: + return False + return False + def encode_chat_with_processor( conversation: List[Dict], processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], @@ -16,8 +25,7 @@ def encode_chat_with_processor( Supports base64-encoded images in the conversation. 
""" - sig = signature(processing_class.__call__ if hasattr(processing_class, "__call__") else processing_class) # work as a replacement for if isinstance(processing_class, ProcessorMixin), because we already type check with lazy importfrom inspect import signature - if "images" in sig.parameters: + if supports_images(processing_class): prompt_text = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, @@ -53,8 +61,7 @@ def encode_text_with_processor( """ Encode plain text and return token IDs, handling both tokenizer and processor. """ - sig = signature(processing_class.__call__ if hasattr(processing_class, "__call__") else processing_class) # work as a replacement for if isinstance(processing_class, ProcessorMixin), because we already type check with lazy importfrom inspect import signature - if "images" in sig.parameters: + if supports_images(processing_class): inputs = processing_class( text=[text], images=None, From 491a8e41824af7c6c63ae63079a1bcee7555e076 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 23:04:57 +0200 Subject: [PATCH 45/58] fix ruff --- verifiers/envs/environment.py | 2 +- verifiers/trainers/grpo_trainer.py | 11 ++++------- verifiers/utils/model_utils.py | 2 +- verifiers/utils/processor_utils.py | 2 +- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 03d9d14e1..df1584c33 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor from copy import deepcopy -from typing import TYPE_CHECKING, Literal, Union, List, Dict, Any +from typing import TYPE_CHECKING, Literal, Union from datasets import Dataset from openai import AsyncOpenAI, OpenAI from verifiers.parsers.parser import Parser diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 90f6add14..b2f8c408b 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -11,9 +11,6 @@ import numpy as np import torch # type: ignore[unresolved-import] import wandb # type: ignore[unresolved-import] -import base64 -from io import BytesIO -import transformers from accelerate.utils import ( # type: ignore[unresolved-import] broadcast_object_list, gather_object, @@ -21,7 +18,6 @@ ) from peft import PeftConfig, get_peft_model # type: ignore[unresolved-import] from torch.utils.data import DataLoader, Sampler # type: ignore[unresolved-import] -from transformers import AutoModelForCausalLM # type: ignore[unresolved-import] from transformers.integrations.deepspeed import ( # type: ignore[unresolved-import] is_deepspeed_zero3_enabled, ) @@ -35,6 +31,7 @@ from transformers.trainer_callback import ( # type: ignore[unresolved-import] TrainerCallback, ) +from transformers import ProcessorMixin, AutoModelForCausalLM # type: ignore[unresolved-import] from transformers.trainer_utils import seed_worker # type: ignore[unresolved-import] from trl.models import ( # type: ignore[unresolved-import] create_reference_model, @@ -536,9 +533,9 @@ def data_collator(features): elif is_deepspeed_zero3_enabled(): model_id = model.config._name_or_path model_init_kwargs = {"torch_dtype": "auto"} - config = AutoConfig.from_pretrained(model_id) - architecture = getattr(transformers, config.architectures[0]) - self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs) + self.ref_model = 
AutoModelForCausalLM.from_pretrained( + model_id, **model_init_kwargs + ) elif is_peft_model(model): # If PEFT is used, the reference model is not needed since the adapter can be disabled # to revert to the initial model. diff --git a/verifiers/utils/model_utils.py b/verifiers/utils/model_utils.py index 27a28c820..47f3e5748 100644 --- a/verifiers/utils/model_utils.py +++ b/verifiers/utils/model_utils.py @@ -1,5 +1,5 @@ from importlib.util import find_spec -from typing import Any, Callable, List, Dict, Union +from typing import Any, Callable import torch # type: ignore[unresolved-import] diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 2cf3fa5b4..e1f850eb8 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -1,5 +1,5 @@ from verifiers.utils.image_utils import _base64_to_pil -from typing import Union, List, Dict, Any, Any, TYPE_CHECKING +from typing import Union, List, Dict, Any, TYPE_CHECKING from inspect import signature if TYPE_CHECKING: From 6a32a10c6e5e086c503a4cbdba31b9f216fd98c8 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 23:16:31 +0200 Subject: [PATCH 46/58] fix py precommit checks for typing --- verifiers/utils/processor_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index e1f850eb8..337c87926 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -1,5 +1,5 @@ from verifiers.utils.image_utils import _base64_to_pil -from typing import Union, List, Dict, Any, TYPE_CHECKING +from typing import Union, List, Dict, Any, Optional, TYPE_CHECKING from inspect import signature if TYPE_CHECKING: @@ -19,7 +19,7 @@ def encode_chat_with_processor( processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], add_generation_prompt: bool = False, add_special_tokens: bool = False, -) -> List[int]: +) -> tuple[list[int], Any, Any]: """ Apply chat template and return token IDs, handling both tokenizer and processor. Supports base64-encoded images in the conversation. From 235a452b518977f00a43e09200c44c7c05bed3ba Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 23:20:31 +0200 Subject: [PATCH 47/58] fix py precommit checks for typing --- verifiers/envs/environment.py | 2 +- verifiers/utils/processor_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index df1584c33..c807570ea 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -807,7 +807,7 @@ def process_chat_format_vllm( state: State, processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, - ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: + ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]: """ Process chat format conversations using incremental prefixes. 
""" diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 337c87926..ac52f82ec 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -1,5 +1,5 @@ from verifiers.utils.image_utils import _base64_to_pil -from typing import Union, List, Dict, Any, Optional, TYPE_CHECKING +from typing import Union, List, Dict, Any, TYPE_CHECKING from inspect import signature if TYPE_CHECKING: From 555c268172eaf78a9387a26fc3c8d011b972eb95 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 23:22:20 +0200 Subject: [PATCH 48/58] fix py precommit checks for typing --- verifiers/envs/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index c807570ea..608aa0cc2 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -931,7 +931,7 @@ def process_completion_format_vllm( state: State, processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"], mask_env_responses: bool = False, - ) -> tuple[list[int], list[int], list[int], list[int], list[float]]: + ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]: """ Process completion format conversations using incremental prefixes. """ From 80fb0bb8a4660342339c6a0eb1562ec48f91a940 Mon Sep 17 00:00:00 2001 From: BLE Date: Fri, 3 Oct 2025 23:25:17 +0200 Subject: [PATCH 49/58] fix py precommit checks for typing --- verifiers/trainers/async_batch_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index d56bd179d..a75701243 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -4,7 +4,7 @@ import threading import time from collections import deque -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, Field @@ -38,7 +38,7 @@ class BatchResult(BaseModel): default_factory=list ) # Store completions for logging prompts: list[Any] = Field(default_factory=list) # Store prompts for logging - answers : list[Any] + answers : Optional[list[Any]] class AsyncBatchGenerator: From 378f8850d901d3cb0823a625288d916794e182d9 Mon Sep 17 00:00:00 2001 From: ulrick Date: Sun, 5 Oct 2025 08:23:11 +0000 Subject: [PATCH 50/58] WIP : fix issue with validators in training by modifying GenerateOutputs typing --- verifiers/envs/environment.py | 1 + verifiers/trainers/async_batch_generator.py | 3 +- verifiers/trainers/grpo_trainer.py | 37 ++++++------- verifiers/types.py | 4 +- verifiers/utils/message_utils.py | 57 ++++++++++----------- verifiers/utils/processor_utils.py | 3 +- 6 files changed, 53 insertions(+), 52 deletions(-) diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 608aa0cc2..553961a3f 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -429,6 +429,7 @@ async def a_generate( reward=[], metrics={}, ) + n = len(results.prompt) # Resolve concurrency knobs diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index a75701243..5aeb1fe84 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -11,7 +11,6 @@ from verifiers import GenerateOutputs from verifiers.types import ProcessedOutputs - class BatchRequest(BaseModel): """Request for batch generation""" @@ -265,6 +264,7 @@ async def 
_generate_batch_async(self, request: BatchRequest) -> BatchResult: """ # Call environment generation self.is_generating = True + env_results = await self.env.a_generate( request.env_inputs, client=self.client, @@ -273,6 +273,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: score_rollouts=True, max_concurrent=request.max_concurrent, ) + self.is_generating = False # Extract all reward-related keys diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b2f8c408b..ecab8e500 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -988,46 +988,48 @@ def _ids_to_tensors( def _gather_batch_data(self, batch_offset: int = 0): """ Gather batch data from all processes and convert PIL images in prompts to base64 image_url. + If prompt already has an image_url, leave it as is. + If prompt has a placeholder {"type": "image"} and x has a PIL image, convert it. """ batches = self._async_dataloader.peek_ahead(batch_offset) - + if batch_offset == 0: batch = batches[0] if batches else None else: batch = batches[batch_offset - 1] if batches else None - + if batch is None: return [], [], [], [] - + if isinstance(batch, dict): batch = [batch] - + prompts = [] for x in batch: prompt = x["prompt"] + image = x.get("image") for message in prompt: content = message.get("content", []) if isinstance(content, list): for c in content: - if isinstance(c, dict): - if c.get("type") == "image": # Convert only if not already base64 - if "image_url" not in c: - if "image" in x: # only convert if PIL image exists - img_url = pil_to_base64_url(x["image"]) - c.clear() - c.update({ - "type": "image_url", - "image_url": {"url": img_url} - }) - elif c.get("type") == "image_url": # Already base64, leave as is - pass + if not isinstance(c, dict): + continue + if c.get("type") == "image_url": # If already base64, skip + continue + if c.get("type") == "image" and "image" in x and x["image"] is not None: # If placeholder and we have a PIL image, convert + img_url = pil_to_base64_url(x["image"]) + c.clear() + c.update({ + "type": "image_url", + "image_url": {"url": img_url} + }) elif isinstance(content, str): pass else: print("Unknown content type:", type(content)) prompts.append(prompt) - + answers = [x["answer"] for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] @@ -1036,7 +1038,6 @@ def _gather_batch_data(self, batch_offset: int = 0): all_answers = gather_object(answers) all_tasks = gather_object(tasks) all_infos = gather_object(infos) - return all_prompts, all_answers, all_tasks, all_infos def _prepare_inputs( # type: ignore diff --git a/verifiers/types.py b/verifiers/types.py index ee2f7e56d..670f789c3 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -55,8 +55,8 @@ class GenerateInputs(BaseModel): class GenerateOutputs(BaseModel): """Pydantic model for generation outputs.""" - prompt: list[Messages] - completion: list[Messages] + prompt: list[list[dict]] + completion: list[list[dict]] answer: list[str] state: list[State] info: list[Info] diff --git a/verifiers/utils/message_utils.py b/verifiers/utils/message_utils.py index ce4dd148e..8661378e0 100644 --- a/verifiers/utils/message_utils.py +++ b/verifiers/utils/message_utils.py @@ -60,38 +60,37 @@ def cleanup_message(message: ChatMessage) -> ChatMessage: if isinstance(content, str): new_message["content"] = content - - else : - for c in content: - c_dict = dict(c) - c_type = c_dict.get("type") - - if c_type == 
"text": - if c_dict.get("text") is not None: - new_message["content"].append({ - "type": "text", - "text": c_dict["text"] - }) - - elif c_type == "image_url": - if "image_url" in c_dict: - new_message["content"].append({ - "type": "image_url", - "image_url": c_dict["image_url"] - }) - - elif c_type == "input_audio": - if "input_audio" in c_dict: - new_message["content"].append({ - "type": "input_audio", - "input_audio": c_dict["input_audio"] - }) + return cast(ChatMessage, new_message) - else: - new_message["content"].append(c_dict) + for c in content: + c_dict = dict(c) + c_type = c_dict.get("type") + + if c_type == "text": + if c_dict.get("text") is not None: + new_message["content"].append({ + "type": "text", + "text": c_dict["text"] + }) + + elif c_type == "image_url": + if "image_url" in c_dict: + new_message["content"].append({ + "type": "image_url", + "image_url": c_dict["image_url"] + }) + + elif c_type == "input_audio": + if "input_audio" in c_dict: + new_message["content"].append({ + "type": "input_audio", + "input_audio": c_dict["input_audio"] + }) - return cast(ChatMessage, new_message) + else: + new_message["content"].append(c_dict) + return cast(ChatMessage, new_message) def cleanup_messages(messages: Messages) -> Messages: if isinstance(messages, str): diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index ac52f82ec..6f43484cf 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -24,14 +24,13 @@ def encode_chat_with_processor( Apply chat template and return token IDs, handling both tokenizer and processor. Supports base64-encoded images in the conversation. """ - + if supports_images(processing_class): prompt_text = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, tokenize=False, ) - images = [] for msg in conversation: for c in msg.get("content", []): From 0fde4d020177b0b026f0756c0b4dcdb6c47fb316 Mon Sep 17 00:00:00 2001 From: BLE Date: Sun, 5 Oct 2025 10:29:38 +0200 Subject: [PATCH 51/58] fix ruff --- verifiers/trainers/grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index ecab8e500..b33495a8d 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -1007,7 +1007,6 @@ def _gather_batch_data(self, batch_offset: int = 0): prompts = [] for x in batch: prompt = x["prompt"] - image = x.get("image") for message in prompt: content = message.get("content", []) if isinstance(content, list): From 193346c054f2f759a732c7d6c51c64cb67971cad Mon Sep 17 00:00:00 2001 From: ulrick Date: Mon, 6 Oct 2025 18:54:54 +0000 Subject: [PATCH 52/58] WIP : working on Hindi OCR by fixing KL olds pixel values, but cannot converge and issue when more batch size --- verifiers/trainers/async_batch_generator.py | 1 + verifiers/trainers/grpo_trainer.py | 30 ++++++++++++++------- verifiers/utils/processor_utils.py | 23 +++++++++++++++- 3 files changed, 44 insertions(+), 10 deletions(-) diff --git a/verifiers/trainers/async_batch_generator.py b/verifiers/trainers/async_batch_generator.py index 5aeb1fe84..b5debe8c1 100644 --- a/verifiers/trainers/async_batch_generator.py +++ b/verifiers/trainers/async_batch_generator.py @@ -283,6 +283,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult: for k in env_results.metrics: all_reward_dict[k] = env_results.metrics[k] + # Process results processed_results = 
self.env.process_env_results_vllm( prompts=env_results.prompt, diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index ecab8e500..c9d06b657 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -303,6 +303,9 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True + + self.processor = processing_class + self._debug_batch_counter = 0 # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -768,8 +771,6 @@ def _get_last_hidden_state( logits_to_keep, pixel_values=None, image_grid_thw=None, - pixel_attention_mask=None, - image_sizes=None, ): if is_peft_model(unwrapped_model): unwrapped_model = unwrapped_model.base_model.model @@ -777,6 +778,11 @@ def _get_last_hidden_state( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} + if pixel_values is not None: + model_inputs["pixel_values"] = pixel_values + if image_grid_thw is not None: + model_inputs["image_grid_thw"] = image_grid_thw + last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state # Exclude the last value: it corresponds to the next token pred last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) @@ -799,6 +805,8 @@ def _get_per_token_logps( 0 ) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] + raw_batch_data = {"input_ids":input_ids,"attention_mask":attention_mask,"pixel_values":pixel_values,"image_grid_thw":image_grid_thw} + for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] @@ -807,10 +815,9 @@ def _get_per_token_logps( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[i : i + batch_size] - start_pixel_idx = image_grid_thw[:i].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: i + batch_size].prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + model_inputs["pixel_values"] = pixel_values[i : i + batch_size] + model_inputs["image_grid_thw"]= image_grid_thw[i : i + batch_size] + model_inputs["pixel_values"] = model_inputs["pixel_values"].reshape(-1, model_inputs["pixel_values"].shape[-1]) elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[i : i + batch_size] @@ -819,6 +826,7 @@ def _get_per_token_logps( # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded model_inputs["logits_to_keep"] = logits_to_keep + 1 + print(model_inputs["pixel_values"].shape) logits = model(**model_inputs).logits # Exclude the last value: it corresponds to the next token pred @@ -1055,7 +1063,6 @@ def _prepare_inputs( # type: ignore self.accelerator.wait_for_everyone() # inputs = list of dicts for all gradient accumulation steps generate_every = self.gradient_accumulation_steps * self.num_iterations - # Check if we need to generate new completions if self._step % generate_every == 0 or self._buffered_inputs is None: # Update weights to vLLM if needed @@ -1241,6 +1248,9 @@ def _prepare_inputs( # type: ignore if has_images: pixel_values = torch.stack(pixel_values_list, dim=0) image_grid_thw = torch.stack(image_grid_list, dim=0) + else : + pixel_values = None + image_grid_thw = None # 
Truncate if needed if self.max_seq_len is not None and input_ids.size(1) > self.max_seq_len: @@ -1280,8 +1290,10 @@ def _prepare_inputs( # type: ignore self.model, input_ids, attention_mask, - logits_to_keep, - batch_size=self.per_device_train_batch_size, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + logits_to_keep=logits_to_keep, + batch_size=self.per_device_train_batch_size ) # Concatenate all data for shuffling diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 6f43484cf..6ce8391f7 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -1,6 +1,7 @@ from verifiers.utils.image_utils import _base64_to_pil from typing import Union, List, Dict, Any, TYPE_CHECKING from inspect import signature +import json if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase, ProcessorMixin @@ -44,9 +45,28 @@ def encode_chat_with_processor( return_tensors="pt", add_special_tokens=add_special_tokens, ) - return inputs["input_ids"][0].tolist(), inputs["image_grid_thw"][0].tolist(), inputs["pixel_values"].tolist() + input_ids_list = inputs["input_ids"][0].tolist() + image_grid_list = inputs["image_grid_thw"][0].tolist() + pixel_values_list = inputs["pixel_values"].tolist() + + # --- DEBUG CODE: Store conversation and encodings --- + debug_data = { + "conversation": conversation, + "input_ids": input_ids_list, + "image_grid_thw": image_grid_list, + "pixel_values_list": pixel_values_list, + } + + # Append the data to a JSON Lines file + file_path = "chat_encodings.jsonl" + with open(file_path, "a") as f: + f.write(json.dumps(debug_data) + "\n") + # ---------------------------------------------------- + + return input_ids_list, image_grid_list, pixel_values_list else: + print("NO_SUPPORT_CHAT") prompt_ids : List[int] = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, @@ -71,6 +91,7 @@ def encode_text_with_processor( pixel_values = inputs.get("pixel_values", [None]).tolist() return input_ids, image_grid, pixel_values else: + print("NO_SUPPORT_ENCODE") prompt_ids: list[int] = processing_class.encode( text ) From aea2b9de114a6fba9422fd9b019dc46f1cdcccdd Mon Sep 17 00:00:00 2001 From: ulrick Date: Mon, 6 Oct 2025 20:39:48 +0000 Subject: [PATCH 53/58] correct trainer, almost working, fix pixel values and schuffle for different data in same batch with pixel values --- verifiers/trainers/grpo_trainer.py | 48 ++++++++++++++---------------- verifiers/utils/processor_utils.py | 16 ---------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index c9d06b657..0375187e1 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -210,31 +210,29 @@ def split_tensor_dict( ] -def shuffle_tensor_dict( - tensor_dict: dict[str, Optional[torch.Tensor]], -) -> dict[str, Optional[torch.Tensor]]: +def shuffle_dict_with_lists( + data_dict: Dict[str, Optional[Union[torch.Tensor, List]]], +) -> Dict[str, Optional[Union[torch.Tensor, List]]]: """ - Shuffles a dictionary of tensors along the first dimension in unison. 
- - Example: - >>> x = torch.arange(6).reshape(3, 2) - >>> y = torch.arange(3).reshape(3, 1) - >>> tensor_dict = {"x": x, "y": y} - >>> shuffle_tensor_dict(tensor_dict) - {'x': tensor([[2, 3], - [0, 1], - [4, 5]]), - 'y': tensor([[1], - [0], - [2]])} + Shuffles a dictionary of tensors and/or lists along the first dimension in unison since pixel values can't be a schufflable tensor at the moment """ - first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) - batch_size = first_tensor.shape[0] + first_item = next(item for item in data_dict.values() if item is not None) + batch_size = len(first_item) + permutation = torch.randperm(batch_size) - return { - key: tensor[permutation] if tensor is not None else None - for key, tensor in tensor_dict.items() - } + + shuffled_dict = {} + for key, value in data_dict.items(): + if value is None: + shuffled_dict[key] = None + elif isinstance(value, torch.Tensor): + shuffled_dict[key] = value[permutation] + elif isinstance(value, list): + shuffled_dict[key] = [value[i] for i in permutation] + else: + shuffled_dict[key] = value + return shuffled_dict + def nanmin(tensor: torch.Tensor) -> torch.Tensor: """ @@ -817,7 +815,7 @@ def _get_per_token_logps( if image_grid_thw is not None and pixel_values is not None: model_inputs["pixel_values"] = pixel_values[i : i + batch_size] model_inputs["image_grid_thw"]= image_grid_thw[i : i + batch_size] - model_inputs["pixel_values"] = model_inputs["pixel_values"].reshape(-1, model_inputs["pixel_values"].shape[-1]) + model_inputs["pixel_values"] = torch.cat(model_inputs["pixel_values"], dim=0) elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[i : i + batch_size] @@ -1246,7 +1244,7 @@ def _prepare_inputs( # type: ignore attention_mask = pad(attention_mask_list, padding_side="right") # type: ignore if has_images: - pixel_values = torch.stack(pixel_values_list, dim=0) + pixel_values = pixel_values_list image_grid_thw = torch.stack(image_grid_list, dim=0) else : pixel_values = None @@ -1308,7 +1306,7 @@ def _prepare_inputs( # type: ignore full_batch["image_grid_thw"] = image_grid_thw # Shuffle and split for gradient accumulation - full_batch = shuffle_tensor_dict(full_batch) + full_batch = shuffle_dict_with_lists(full_batch) self._buffered_inputs = split_tensor_dict( full_batch, self.gradient_accumulation_steps ) diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 6ce8391f7..765d7a01d 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -49,24 +49,9 @@ def encode_chat_with_processor( image_grid_list = inputs["image_grid_thw"][0].tolist() pixel_values_list = inputs["pixel_values"].tolist() - # --- DEBUG CODE: Store conversation and encodings --- - debug_data = { - "conversation": conversation, - "input_ids": input_ids_list, - "image_grid_thw": image_grid_list, - "pixel_values_list": pixel_values_list, - } - - # Append the data to a JSON Lines file - file_path = "chat_encodings.jsonl" - with open(file_path, "a") as f: - f.write(json.dumps(debug_data) + "\n") - # ---------------------------------------------------- - return input_ids_list, image_grid_list, pixel_values_list else: - print("NO_SUPPORT_CHAT") prompt_ids : List[int] = processing_class.apply_chat_template( conversation=conversation, add_generation_prompt=add_generation_prompt, @@ -91,7 +76,6 @@ def encode_text_with_processor( pixel_values = inputs.get("pixel_values", [None]).tolist() return input_ids, image_grid, pixel_values else: - 
print("NO_SUPPORT_ENCODE") prompt_ids: list[int] = processing_class.encode( text ) From 1ed9ca6856f101e5e30d00ff26b752f497def0e3 Mon Sep 17 00:00:00 2001 From: Ulrick Date: Wed, 8 Oct 2025 20:53:03 +0000 Subject: [PATCH 54/58] fix grpo trainer --- verifiers/trainers/grpo_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index 0375187e1..56a6d21b6 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -467,11 +467,11 @@ def filter_by_prompt_length(example, processing_class): return len(prompt_ids) <= max_length original_size = len(train_dataset) - train_dataset = train_dataset.filter( - filter_by_prompt_length, - num_proc=self.max_data_workers, - fn_kwargs={"processing_class": processing_class}, - ) + #train_dataset = train_dataset.filter( + # filter_by_prompt_length, + # num_proc=self.max_data_workers, + # fn_kwargs={"processing_class": processing_class}, + #) filtered_size = len(train_dataset) if filtered_size < original_size: self.logger.info( From 016cc35e4e35fe9e0573761c40ca3e7fbfd2545f Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 8 Oct 2025 22:53:46 +0200 Subject: [PATCH 55/58] clean debug and fix ruff --- verifiers/trainers/grpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index b608a2893..739271001 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -301,9 +301,6 @@ def __init__( # Suppress irrelevant warning model.warnings_issued["estimate_tokens"] = True - - self.processor = processing_class - self._debug_batch_counter = 0 # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -803,7 +800,6 @@ def _get_per_token_logps( 0 ) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] - raw_batch_data = {"input_ids":input_ids,"attention_mask":attention_mask,"pixel_values":pixel_values,"image_grid_thw":image_grid_thw} for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] From fff88bfcff7f79f74809359f4c44834e54a5c9b8 Mon Sep 17 00:00:00 2001 From: BLE Date: Wed, 8 Oct 2025 22:55:38 +0200 Subject: [PATCH 56/58] clean ruff --- verifiers/utils/processor_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index 765d7a01d..e625424da 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -1,7 +1,6 @@ from verifiers.utils.image_utils import _base64_to_pil from typing import Union, List, Dict, Any, TYPE_CHECKING from inspect import signature -import json if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase, ProcessorMixin From d30bfdc6e866e1fe49a492e17ba7ea516bce976d Mon Sep 17 00:00:00 2001 From: Ulrick Date: Sun, 19 Oct 2025 19:48:23 +0000 Subject: [PATCH 57/58] adapt code to multiple images in a single sequence --- verifiers/trainers/grpo_trainer.py | 81 ++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index f6a7be52a..f51cc882a 100644 --- a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -989,16 +989,27 @@ def _ids_to_tensors( def _gather_batch_data(self, batch_offset: int = 0): """ - Gather batch data from all processes and convert PIL images 
in prompts to base64 image_url. - If prompt already has an image_url, leave it as is. - If prompt has a placeholder {"type": "image"} and x has a PIL image, convert it. + Gather batch data from all processes and convert PIL images (single or multiple) + in prompts to base64 image_url. + + Handles: + - 'image': single PIL.Image + - 'images': list of PIL.Image objects (matched in order to placeholders) + - placeholders {"type": "image"} + - existing image_url (skipped) + - direct PIL.Image objects inside content """ + batches = self._async_dataloader.peek_ahead(batch_offset) - if batch_offset == 0: - batch = batches[0] if batches else None - else: - batch = batches[batch_offset - 1] if batches else None + if not batches: + return [], [], [], [] + + batch = ( + batches[0] + if batch_offset == 0 + else batches[batch_offset - 1] if batch_offset - 1 < len(batches) else None + ) if batch is None: return [], [], [], [] @@ -1007,31 +1018,56 @@ def _gather_batch_data(self, batch_offset: int = 0): batch = [batch] prompts = [] + for x in batch: - prompt = x["prompt"] + prompt = x.get("prompt", []) + single_image = x.get("image", None) + multiple_images = x.get("images", []) + img_index = 0 # track which image in x["images"] we're up to + for message in prompt: content = message.get("content", []) - if isinstance(content, list): - for c in content: - if not isinstance(c, dict): - continue - if c.get("type") == "image_url": # If already base64, skip - continue - if c.get("type") == "image" and "image" in x and x["image"] is not None: # If placeholder and we have a PIL image, convert - img_url = pil_to_base64_url(x["image"]) + if not isinstance(content, list): + continue + + for i, c in enumerate(content): + if hasattr(c, "save"): + img_url = pil_to_base64_url(c) + content[i] = { + "type": "image_url", + "image_url": {"url": img_url}, + } + continue + + if not isinstance(c, dict): + continue + + ctype = c.get("type") + + if ctype == "image_url": #already an image_url -> skip + continue + + if ctype == "image": # placeholder and we have an image list + # Prefer multiple_images if available + if multiple_images and img_index < len(multiple_images): + pil_img = multiple_images[img_index] + img_index += 1 + elif single_image is not None: + pil_img = single_image + else: + pil_img = None + + if pil_img is not None: + img_url = pil_to_base64_url(pil_img) c.clear() c.update({ "type": "image_url", - "image_url": {"url": img_url} + "image_url": {"url": img_url}, }) - elif isinstance(content, str): - pass - else: - print("Unknown content type:", type(content)) prompts.append(prompt) - answers = [x["answer"] for x in batch] + answers = [x.get("answer") for x in batch] tasks = [x.get("task", "default") for x in batch] infos = [x.get("info", {}) for x in batch] @@ -1039,6 +1075,7 @@ def _gather_batch_data(self, batch_offset: int = 0): all_answers = gather_object(answers) all_tasks = gather_object(tasks) all_infos = gather_object(infos) + return all_prompts, all_answers, all_tasks, all_infos def _prepare_inputs( # type: ignore From 0ed1ff624a81762cc76e4f31b94e9f5e2b93592d Mon Sep 17 00:00:00 2001 From: Ulrick Date: Wed, 22 Oct 2025 17:22:05 +0000 Subject: [PATCH 58/58] add behavior for multi images in prompt --- verifiers/trainers/grpo_trainer.py | 5 +++-- verifiers/types.py | 2 +- verifiers/utils/processor_utils.py | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/verifiers/trainers/grpo_trainer.py b/verifiers/trainers/grpo_trainer.py index f51cc882a..aa4cd50bf 100644 --- 
a/verifiers/trainers/grpo_trainer.py +++ b/verifiers/trainers/grpo_trainer.py @@ -800,7 +800,7 @@ def _get_per_token_logps( 0 ) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] - + for i in range(0, input_ids.size(0), batch_size): input_ids_batch = input_ids[i : i + batch_size] attention_mask_batch = attention_mask[i : i + batch_size] @@ -812,6 +812,7 @@ def _get_per_token_logps( model_inputs["pixel_values"] = pixel_values[i : i + batch_size] model_inputs["image_grid_thw"]= image_grid_thw[i : i + batch_size] model_inputs["pixel_values"] = torch.cat(model_inputs["pixel_values"], dim=0) + model_inputs["image_grid_thw"] = model_inputs["image_grid_thw"].reshape(-1, *model_inputs["image_grid_thw"].shape[2:]) elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[i : i + batch_size] @@ -820,7 +821,6 @@ def _get_per_token_logps( # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded model_inputs["logits_to_keep"] = logits_to_keep + 1 - print(model_inputs["pixel_values"].shape) logits = model(**model_inputs).logits # Exclude the last value: it corresponds to the next token pred @@ -835,6 +835,7 @@ def _get_per_token_logps( logits, input_ids_batch ) # compute logprobs for the input tokens all_logps.append(logps) + return torch.cat(all_logps, dim=0) def _move_model_to_vllm(self): diff --git a/verifiers/types.py b/verifiers/types.py index 670f789c3..6dc3068ce 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -84,7 +84,7 @@ class ProcessedOutputs(BaseModel): prompt_ids: list[list[int]] prompt_mask: list[list[int]] - image_grid_thw: Optional[list[Optional[list[int]]]] = None + image_grid_thw: Optional[list[Optional[list[list[int]]]]] = None pixel_values: Optional[list[Optional[list[list[float]]]]] = None completion_ids: list[list[int]] completion_mask: list[list[int]] diff --git a/verifiers/utils/processor_utils.py b/verifiers/utils/processor_utils.py index e625424da..890101426 100644 --- a/verifiers/utils/processor_utils.py +++ b/verifiers/utils/processor_utils.py @@ -45,7 +45,7 @@ def encode_chat_with_processor( add_special_tokens=add_special_tokens, ) input_ids_list = inputs["input_ids"][0].tolist() - image_grid_list = inputs["image_grid_thw"][0].tolist() + image_grid_list = inputs["image_grid_thw"].tolist() pixel_values_list = inputs["pixel_values"].tolist() return input_ids_list, image_grid_list, pixel_values_list @@ -71,7 +71,7 @@ def encode_text_with_processor( return_tensors="pt", ) input_ids = inputs["input_ids"][0].tolist() - image_grid = inputs.get("image_grid_thw", [None])[0].tolist() + image_grid = inputs.get("image_grid_thw", [None]).tolist() pixel_values = inputs.get("pixel_values", [None]).tolist() return input_ids, image_grid, pixel_values else:
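
A minimal usage sketch of the helpers introduced across this series (pil_to_base64_url and encode_chat_with_processor). The checkpoint name and the AutoProcessor loading are assumptions for illustration only; any processor whose __call__ accepts an images argument should follow the same path, and a plain tokenizer falls back to the text-only branch.

    from PIL import Image
    from transformers import AutoProcessor  # assumed way to obtain a multimodal processor

    from verifiers.utils.image_utils import pil_to_base64_url
    from verifiers.utils.processor_utils import encode_chat_with_processor

    # Hypothetical Qwen2-VL-style checkpoint; its processor exposes an `images` kwarg,
    # so supports_images() treats it as a processor rather than a plain tokenizer.
    processing_class = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

    image = Image.new("RGB", (64, 64), color="white")
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": pil_to_base64_url(image)}},
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ]

    # Returns (input_ids, image_grid_thw, pixel_values); the last two are None when
    # processing_class is a text-only tokenizer.
    input_ids, image_grid_thw, pixel_values = encode_chat_with_processor(
        conversation,
        processing_class,
        add_generation_prompt=True,
    )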