Changes from all commits (69 commits)
0195a06  add first version to deal with image on verifiers, dealt with environ… (UlrickBL, Sep 17, 2025)
7ad272a  fix some issues with trainer (UlrickBL, Sep 21, 2025)
eb31d6e  fix filter by prompt length to deal with images (UlrickBL, Sep 21, 2025)
7d8051f  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Sep 21, 2025)
8307874  WIP : fix the dataset filtering with images, working. Arrived at the … (UlrickBL, Sep 21, 2025)
1a5a662  add batch treatment with base64 for openAI format with vllm (UlrickBL, Sep 21, 2025)
7ea887f  Merge branch 'add_vlm_support' of github.com:UlrickBL/verifiers_vlm i… (UlrickBL, Sep 21, 2025)
652a7c3  WIP async generation working, issue in processing for now (UlrickBL, Sep 21, 2025)
1c4ce78  fix ruff issues (UlrickBL, Sep 21, 2025)
121913f  fix ruff issues (UlrickBL, Sep 21, 2025)
7460db2  WIP : deal with after rollout : processing the input and output and p… (UlrickBL, Sep 21, 2025)
5219e08  WIP : pass pixel values and image grid all way long to prepare input,… (UlrickBL, Sep 22, 2025)
eac0ccf  WIP : fix encoding and treatment, issue with pydantic in ProcessedOut… (UlrickBL, Sep 23, 2025)
e000737  fix until compute loss issue (UlrickBL, Sep 23, 2025)
f31ddb1  fix pixel values (UlrickBL, Sep 23, 2025)
d0471ff  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Sep 23, 2025)
5e3ac99  WIP : try to fix compute loss (UlrickBL, Sep 23, 2025)
b44d3f9  WIP : working until self._get_per_token_logps model(**model_inputs).l… (UlrickBL, Sep 23, 2025)
b4ddd58  FIX Batch for text data while keeping string for image, however, issu… (UlrickBL, Sep 23, 2025)
dd4ed6d  fix issue with shape (UlrickBL, Sep 24, 2025)
076395d  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Sep 24, 2025)
58a39ce  check pixel values issue (UlrickBL, Sep 24, 2025)
ba0302d  WIP : VL training working end to end, but rebus too complex -> need t… (UlrickBL, Sep 25, 2025)
3004416  update typing (UlrickBL, Sep 28, 2025)
5209a04  add image logging and answer logging in wandb to improve data diging (UlrickBL, Sep 28, 2025)
47c1df0  make code robust to text only (UlrickBL, Sep 28, 2025)
1b515a8  logging of image and answer working + full training works, maybe issu… (UlrickBL, Sep 29, 2025)
277ed43  fix ruff (UlrickBL, Oct 1, 2025)
c6c0d67  fix ruff (UlrickBL, Oct 1, 2025)
6652dac  fix ruff (UlrickBL, Oct 1, 2025)
6dc9d40  fix ruff (UlrickBL, Oct 1, 2025)
289dc9f  fix ruff (UlrickBL, Oct 1, 2025)
d1f5f07  fix ruff (UlrickBL, Oct 1, 2025)
8e1abbe  fix ruff (UlrickBL, Oct 1, 2025)
1971169  fix ruff (UlrickBL, Oct 1, 2025)
c543303  modify encode so it is robust to tests (UlrickBL, Oct 1, 2025)
c043abf  change test process chat format to adapt to new output (UlrickBL, Oct 1, 2025)
b1e917f  change test process chat format to adapt to new output (UlrickBL, Oct 1, 2025)
55ad406  change test process chat format to adapt to new output (UlrickBL, Oct 1, 2025)
8a37577  change test process chat format to adapt to new output (UlrickBL, Oct 1, 2025)
bbebab7  modif encoding to be robust to text only (UlrickBL, Oct 1, 2025)
649d87a  fix conflict with main (UlrickBL, Oct 1, 2025)
ae412a5  remove transformers from base dependencies (UlrickBL, Oct 3, 2025)
d2b6166  in grpo trainer, only convert images from PIL if not base64 (UlrickBL, Oct 3, 2025)
02caa79  Create image utils for base64 and PIL transformation, move processing… (UlrickBL, Oct 3, 2025)
030dddc  add image utils (UlrickBL, Oct 3, 2025)
50eb136  add processor utils with lazy imports and no transformers to use in e… (UlrickBL, Oct 3, 2025)
653efb6  clean type checking to avoid transformers (UlrickBL, Oct 3, 2025)
3512a3b  fix issue with not callable Processor (UlrickBL, Oct 3, 2025)
491a8e4  fix ruff (UlrickBL, Oct 3, 2025)
6a32a10  fix py precommit checks for typing (UlrickBL, Oct 3, 2025)
235a452  fix py precommit checks for typing (UlrickBL, Oct 3, 2025)
555c268  fix py precommit checks for typing (UlrickBL, Oct 3, 2025)
80fb0bb  fix py precommit checks for typing (UlrickBL, Oct 3, 2025)
378f885  WIP : fix issue with validators in training by modifying GenerateOutp… (UlrickBL, Oct 5, 2025)
0fde4d0  fix ruff (UlrickBL, Oct 5, 2025)
193346c  WIP : working on Hindi OCR by fixing KL olds pixel values, but cannot… (UlrickBL, Oct 6, 2025)
ce432a6  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Oct 6, 2025)
aea2b9d  correct trainer, almost working, fix pixel values and schuffle for di… (UlrickBL, Oct 6, 2025)
9b6f075  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Oct 6, 2025)
1ed9ca6  fix grpo trainer (UlrickBL, Oct 8, 2025)
016cc35  clean debug and fix ruff (UlrickBL, Oct 8, 2025)
1323ba6  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Oct 8, 2025)
fff88bf  clean ruff (UlrickBL, Oct 8, 2025)
13f939b  Merge branch 'main' into add_vlm_support (UlrickBL, Oct 9, 2025)
7817f25  Merge branch 'main' into add_vlm_support (UlrickBL, Oct 11, 2025)
d30bfdc  adapt code to multiple images in a single sequence (UlrickBL, Oct 19, 2025)
0ed1ff6  add behavior for multi images in prompt (UlrickBL, Oct 22, 2025)
e4f558f  Merge branch 'add_vlm_support' of https://github.com/UlrickBL/verifie… (UlrickBL, Oct 22, 2025)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
     "requests",
     "rich",
     "textual",
+    "pillow>=10.0.0",
     "pydantic>=2.11.9",
     "prime-sandboxes>=0.1.0",
 ]
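
Pillow becomes a base dependency because prompts can now carry images. The commits describe converting between PIL images and base64 data URLs for the OpenAI-compatible vLLM endpoint; a minimal sketch of such helpers, with names assumed rather than taken from the PR's image utils:

# Sketch only: hypothetical helper names, not necessarily the PR's
# verifiers/utils/image_utils.py API.
import base64
import io

from PIL import Image


def pil_to_base64_url(image: Image.Image, fmt: str = "PNG") -> str:
    """Encode a PIL image as a base64 data URL for chat message content."""
    buffer = io.BytesIO()
    image.save(buffer, format=fmt)
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/{fmt.lower()};base64,{encoded}"


def base64_url_to_pil(data_url: str) -> Image.Image:
    """Decode a base64 data URL (or bare base64 payload) back into a PIL image."""
    payload = data_url.split(",", 1)[1] if "," in data_url else data_url
    return Image.open(io.BytesIO(base64.b64decode(payload)))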
4 changes: 4 additions & 0 deletions tests/test_environment.py
@@ -255,6 +255,8 @@ def apply_template(conversation, tokenize=False, add_generation_prompt=True):
         (
             prompt_ids,
             prompt_mask,
+            prompt_image_grid,
+            prompt_pixel_value,
             completion_ids,
             completion_mask,
             completion_logprobs,
@@ -300,6 +302,8 @@ def test_process_completion_format(self, mock_openai_client, sample_dataset):
         (
             prompt_ids,
             prompt_mask,
+            prompt_image_grid,
+            prompt_pixel_value,
             completion_ids,
             completion_mask,
             completion_logprobs,
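
Callers of the vLLM processing helpers now unpack a seven-element tuple instead of five. A sketch of the new shape, assuming a text-only tokenizer where the image fields come back empty:

# Sketch: unpacking the enlarged return value of process_chat_format_vllm.
# For text-only runs the two image fields are assumed to come back empty.
(
    prompt_ids,
    prompt_mask,
    prompt_image_grid,
    prompt_pixel_value,
    completion_ids,
    completion_mask,
    completion_logprobs,
) = env.process_chat_format_vllm(
    prompt=prompt,
    completion=completion,
    state=state,
    processing_class=tokenizer,
)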
67 changes: 47 additions & 20 deletions verifiers/envs/environment.py
@@ -4,8 +4,7 @@
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
-from typing import TYPE_CHECKING, Literal
-
+from typing import TYPE_CHECKING, Literal, Union
 from datasets import Dataset
 from openai import AsyncOpenAI, BadRequestError, OpenAI
 
@@ -27,6 +26,7 @@
     SamplingArgs,
     State,
 )
+from verifiers.utils.processor_utils import encode_text_with_processor, encode_chat_with_processor
 from verifiers.utils.message_utils import (
     cleanup_messages,
     get_overlong_prompt_dummy_response,
@@ -35,10 +35,9 @@
 
 if TYPE_CHECKING:
     from transformers.tokenization_utils_base import (  # type: ignore
-        PreTrainedTokenizerBase,
+        PreTrainedTokenizerBase, ProcessorMixin
     )
-
 
 
 class Environment(ABC):
     """
     Base class for all environments.
@@ -69,7 +68,6 @@ def __init__(
             self.logger.warning(
                 "The parser and rubric parser are different. This may cause unexpected behavior."
             )
-
         if self.message_type == "chat":
             if dataset is not None:
                 self.dataset = self.format_dataset(
@@ -228,6 +226,7 @@ async def get_model_response(
         ):
             sampling_args.pop("max_completion_tokens")
         clean_sampling_args = {k: v for k, v in sampling_args.items() if v is not None}
+
         try:
             if message_type == "chat":
                 assert isinstance(prompt, list)
@@ -444,6 +443,7 @@ async def a_generate(
                 reward=[],
                 metrics={},
             )
+
         n = len(results.prompt)
 
         # Resolve concurrency knobs
@@ -593,6 +593,7 @@ def generate(
     ) -> GenerateOutputs:
         if isinstance(client, OpenAI):
             client = AsyncOpenAI(api_key=client.api_key, base_url=client.base_url)
+
         coro = self.a_generate(
             inputs,
             client,
@@ -819,9 +820,9 @@ def process_chat_format_vllm(
         prompt: list[ChatMessage],
         completion: list[ChatMessage],
         state: State,
-        processing_class: "PreTrainedTokenizerBase",
+        processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"],
         mask_env_responses: bool = False,
-    ) -> tuple[list[int], list[int], list[int], list[int], list[float]]:
+    ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]:
         """
         Process chat format conversations using incremental prefixes.
         """
@@ -836,10 +837,13 @@
                 zipped.append((turn, None))
         assert len(responses) == responses_idx, "Responses not fully consumed"
         assert len(zipped) == len(completion), "Length mismatch"
-        prompt_ids: list[int] = processing_class.apply_chat_template(
-            conversation=prompt,  # type: ignore
+
+        prompt_ids, prompt_image_grid, prompt_pixel_value = encode_chat_with_processor(
+            conversation=prompt,
+            processing_class=processing_class,
             add_generation_prompt=True,
         )
+
         messages_consumed = [m for m in prompt]
         prompt_mask: list[int] = [0] * len(prompt_ids)
         completion_ids: list[int] = []
@@ -900,13 +904,15 @@ def deserialize_tool_call(tool_call) -> dict:
                 while j < len(zipped) and zipped[j][0]["role"] != "assistant":
                     consecutive_messages.append(zipped[j][0])
                     j += 1
-                token_prefix: list[int] = processing_class.apply_chat_template(
-                    conversation=messages_consumed  # type: ignore
+                token_prefix, token_prefix_image_grid, token_prefix_pixel_values = encode_chat_with_processor(
+                    conversation=messages_consumed,  # type: ignore
+                    processing_class=processing_class,
+                    add_generation_prompt=False,
                 )
-                token_prefix_with_turn: list[int] = (
-                    processing_class.apply_chat_template(
-                        conversation=messages_consumed + consecutive_messages,  # type: ignore
-                    )
+                token_prefix_with_turn, token_prefix_with_turn_image_grid, token_prefix_with_turn_pixel_values = encode_chat_with_processor(
+                    conversation=messages_consumed + consecutive_messages,  # type: ignore
+                    processing_class=processing_class,
+                    add_generation_prompt=False,
                 )
                 assert token_prefix_with_turn[: len(token_prefix)] == token_prefix, (
                     f"Token prefix mismatch. Token prefix: {token_prefix}, token prefix with turn: {token_prefix_with_turn}"
@@ -916,6 +922,7 @@ def deserialize_tool_call(tool_call) -> dict:
                 completion_turn_mask = [0] * len(completion_turn_ids)
             else:
                 completion_turn_mask = [1] * len(completion_turn_ids)
+
             completion_turn_logprobs = [0.0] * len(completion_turn_ids)
             completion_ids.extend(completion_turn_ids)
             completion_mask.extend(completion_turn_mask)
@@ -925,6 +932,8 @@
         return (
             prompt_ids,
             prompt_mask,
+            prompt_image_grid,
+            prompt_pixel_value,
             completion_ids,
             completion_mask,
             completion_logprobs,
@@ -935,9 +944,9 @@ def process_completion_format_vllm(
         prompt: str,
         completion: str,
         state: State,
-        processing_class: "PreTrainedTokenizerBase",
+        processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"],
        mask_env_responses: bool = False,
-    ) -> tuple[list[int], list[int], list[int], list[int], list[float]]:
+    ) -> tuple[list[int], list[int], list[int], list[int], list[int], list[int], list[float]]:
         """
         Process completion format conversations using incremental prefixes.
         """
@@ -958,12 +967,16 @@
         idx = response_start_idx + len(response_text)
         assert idx == len(completion), "Completion not fully consumed"
 
-        prompt_ids: list[int] = processing_class.encode(prompt)
+        prompt_ids, prompt_image_grid, prompt_pixel_value = encode_text_with_processor(
+            text=prompt,
+            processing_class=processing_class,
+        )
         rollout_consumed = prompt
         prompt_mask: list[int] = [0] * len(prompt_ids)
         completion_ids: list[int] = []
         completion_mask: list[int] = []
         completion_logprobs: list[float] = []
+
         i = 0
         while i < len(zipped):
             text, response = zipped[i]
@@ -1000,6 +1013,8 @@
         return (
             prompt_ids,
             prompt_mask,
+            prompt_image_grid,
+            prompt_pixel_value,
             completion_ids,
             completion_mask,
             completion_logprobs,
@@ -1011,7 +1026,7 @@ def process_env_results_vllm(
         completions: list[Messages],
         states: list[State],
         rewards: list[float],
-        processing_class: "PreTrainedTokenizerBase",
+        processing_class: Union["PreTrainedTokenizerBase", "ProcessorMixin"],
         max_seq_len: int = -1,
         mask_env_responses: bool = False,
         mask_truncated_completions: bool = False,
@@ -1024,6 +1039,8 @@
 
         all_prompt_ids = []
         all_prompt_masks = []
+        all_prompt_image_grid = []
+        all_prompt_pixel_value = []
         all_completion_ids = []
         all_completion_masks = []
         all_completion_logprobs = []
@@ -1037,6 +1054,8 @@
             (
                 prompt_ids,
                 prompt_mask,
+                prompt_image_grid,
+                prompt_pixel_value,
                 completion_ids,
                 completion_mask,
                 completion_logprobs,
@@ -1048,6 +1067,8 @@
             (
                 prompt_ids,
                 prompt_mask,
+                prompt_image_grid,
+                prompt_pixel_value,
                 completion_ids,
                 completion_mask,
                 completion_logprobs,
@@ -1080,16 +1101,22 @@
             )
             all_prompt_ids.append(prompt_ids)
             all_prompt_masks.append(prompt_mask)
+            all_prompt_image_grid.append(prompt_image_grid)
+            all_prompt_pixel_value.append(prompt_pixel_value)
             all_completion_ids.append(completion_ids)
             all_completion_masks.append(completion_mask)
             all_completion_logprobs.append(completion_logprobs)
+
             if zero_truncated_completions and is_truncated:
                 all_rewards.append(0)
             else:
                 all_rewards.append(reward)
+
         return ProcessedOutputs(
             prompt_ids=all_prompt_ids,
             prompt_mask=all_prompt_masks,
+            image_grid_thw=all_prompt_image_grid,
+            pixel_values=all_prompt_pixel_value,
             completion_ids=all_completion_ids,
             completion_mask=all_completion_masks,
             completion_logprobs=all_completion_logprobs,
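
The diff calls encode_text_with_processor and encode_chat_with_processor from verifiers/utils/processor_utils, whose implementation is not shown here. A rough sketch of the intended behavior, assuming a Qwen2-VL-style processor and a hypothetical image-extraction helper; the real module reportedly uses lazy imports so transformers stays optional:

# Sketch only: the branching heuristic and _extract_images helper are
# assumptions, not the PR's actual processor_utils implementation.
from typing import Any


def _extract_images(conversation: list[dict]) -> list:
    """Hypothetical helper: collect images from multimodal message content."""
    images = []
    for message in conversation:
        content = message.get("content")
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "image":
                    images.append(part["image"])
    return images


def encode_chat_with_processor(
    conversation: list[dict],
    processing_class: Any,
    add_generation_prompt: bool = False,
) -> tuple[list[int], list, list]:
    """Tokenize a chat, returning (ids, image_grid_thw, pixel_values).

    Plain tokenizers yield empty image fields, so text-only environments
    keep working unchanged.
    """
    if not hasattr(processing_class, "image_processor"):
        # Plain tokenizer path: identical to the pre-VLM behavior.
        ids = processing_class.apply_chat_template(
            conversation=conversation,
            add_generation_prompt=add_generation_prompt,
        )
        return ids, [], []
    # Multimodal processor path: render the template, then encode text + images.
    text = processing_class.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=add_generation_prompt
    )
    images = _extract_images(conversation)
    inputs = processing_class(
        text=[text], images=images or None, return_tensors="pt"
    )
    ids = inputs["input_ids"][0].tolist()
    return ids, inputs.get("image_grid_thw", []), inputs.get("pixel_values", [])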
8 changes: 6 additions & 2 deletions verifiers/trainers/async_batch_generator.py
@@ -4,14 +4,13 @@
 import threading
 import time
 from collections import deque
-from typing import Any
+from typing import Any, Optional
 
 from pydantic import BaseModel, Field
 
 from verifiers import GenerateOutputs
 from verifiers.types import ProcessedOutputs
 
-
 class BatchRequest(BaseModel):
     """Request for batch generation"""
 
@@ -38,6 +37,7 @@ class BatchResult(BaseModel):
         default_factory=list
     )  # Store completions for logging
     prompts: list[Any] = Field(default_factory=list)  # Store prompts for logging
+    answers: Optional[list[Any]]
 
 
 class AsyncBatchGenerator:
@@ -264,6 +264,7 @@ async def _generate_batch_async(self, request: BatchRequest) -> BatchResult:
         """
         # Call environment generation
         self.is_generating = True
+
         env_results = await self.env.a_generate(
             request.env_inputs,
             client=self.client,
@@ -272,6 +273,7 @@
             score_rollouts=True,
             max_concurrent=request.max_concurrent,
         )
+
         self.is_generating = False
 
         # Extract all reward-related keys
@@ -281,6 +283,7 @@
         for k in env_results.metrics:
             all_reward_dict[k] = env_results.metrics[k]
 
+
         # Process results
         processed_results = self.env.process_env_results_vllm(
             prompts=env_results.prompt,
@@ -300,6 +303,7 @@
             all_reward_dict=all_reward_dict,
             completions=env_results.completion,
             prompts=env_results.prompt,
+            answers=request.env_inputs.get("answer"),
         )
 
     async def _evaluate_async(self, num_samples: int = -1) -> GenerateOutputs:
Expand Down