6 changes: 5 additions & 1 deletion src/lighteval/logging/info_loggers.py
@@ -87,6 +87,7 @@ class GeneralConfigLogger:
model_size: str = None

generation_parameters: dict | None = None
chat_template_parameters: dict | None = None
Review comment (Member): Turbo nit, but perhaps it's better to call this chat_template_kwargs to align with vLLM and TRL?


# Nanotron config
config: "Config" = None
@@ -129,7 +130,9 @@ def log_args_info(
self.job_id = job_id
self.config = config

def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) -> None:
def log_model_info(
self, generation_parameters: dict, model_info: ModelInfo, chat_template_parameters: dict
) -> None:
"""
Logs the model information.

@@ -139,6 +142,7 @@ def log_model_info(self, generation_parameters: dict, model_info: ModelInfo) ->

"""
self.generation_parameters = generation_parameters
self.chat_template_parameters = chat_template_parameters
self.model_name = model_info.model_name
self.model_sha = model_info.model_sha
self.model_dtype = model_info.model_dtype
1 change: 1 addition & 0 deletions src/lighteval/main_baseline.py
@@ -89,6 +89,7 @@ def baseline(
model_dtype=None,
model_size=None,
),
{},
)
evaluation_tracker.task_config_logger.log(tasks_dict)

1 change: 0 additions & 1 deletion src/lighteval/models/custom/custom_model.py
@@ -70,5 +70,4 @@ def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
An example of a custom model can be found in `examples/custom_models/google_translate_model.py`.
"""

model_name: str
model_definition_file_path: str
1 change: 0 additions & 1 deletion src/lighteval/models/endpoints/endpoint_model.py
@@ -95,7 +95,6 @@ class ServerlessEndpointModelConfig(ModelConfig):
```
"""

model_name: str
add_special_tokens: bool = True
batch_size: int = 1

1 change: 0 additions & 1 deletion src/lighteval/models/litellm_model.py
@@ -94,7 +94,6 @@ class LiteLLMModelConfig(ModelConfig):
```
"""

model_name: str
provider: str | None = None
base_url: str | None = None
api_key: str | None = None
15 changes: 15 additions & 0 deletions src/lighteval/models/model_input.py
@@ -232,3 +232,18 @@ def to_sglang_dict(self) -> dict:
"min_new_tokens": self.min_new_tokens,
}
return {k: v for k, v in args.items() if v is not None}


class ChatTemplateParameters(BaseModel):
reasoning_effort: str = None

def to_transformers_dict(self) -> dict:
"""Selects relevant chat template parameters for transformers models.

Returns:
dict: Valid parameters for the chat template
"""
args = {
"reasoning_effort": self.reasoning_effort,
}
return {k: v for k, v in args.items() if v is not None}
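
For context, a rough usage sketch of the new helper (assuming the class is importable from lighteval.models.model_input as added in this diff): only fields that were explicitly set survive the conversion, mirroring the existing GenerationParameters.to_*_dict helpers.

```python
# Sketch only; assumes lighteval is installed with this change applied.
from lighteval.models.model_input import ChatTemplateParameters

params = ChatTemplateParameters(reasoning_effort="medium")
print(params.to_transformers_dict())  # {'reasoning_effort': 'medium'}

# Unset fields are dropped, so a default instance adds no chat template kwargs.
print(ChatTemplateParameters().to_transformers_dict())  # {}
```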
1 change: 0 additions & 1 deletion src/lighteval/models/sglang/sglang_model.py
@@ -109,7 +109,6 @@ class SGLangModelConfig(ModelConfig):
```
"""

model_name: str
load_format: str = "auto"
dtype: str = "auto"
tp_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism
6 changes: 4 additions & 2 deletions src/lighteval/models/transformers/transformers_model.py
@@ -133,7 +133,6 @@ class TransformersModelConfig(ModelConfig):
(bitsandbytes for 4-bit/8-bit quantization).
"""

model_name: str
tokenizer: str | None = None
subfolder: str | None = None
revision: str = "main"
@@ -230,7 +229,10 @@ def __init__(
)

self.prompt_manager = PromptManager(
use_chat_template=self.use_chat_template, tokenizer=self.tokenizer, system_prompt=config.system_prompt
use_chat_template=self.use_chat_template,
tokenizer=self.tokenizer,
system_prompt=config.system_prompt,
chat_template_parameters=config.chat_template_parameters,
)

def cleanup(self):
@@ -104,7 +104,6 @@ class VLMTransformersModelConfig(ModelConfig):
loading.
"""

model_name: str
processor: str | None = None
use_fast_image_processor: bool | None = None
subfolder: str | None = None
21 changes: 17 additions & 4 deletions src/lighteval/models/utils.py
@@ -34,7 +34,7 @@
from transformers import AutoTokenizer
from transformers.models.auto.configuration_auto import AutoConfig

from lighteval.models.model_input import GenerationParameters
from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters


logger = logging.getLogger(__name__)
@@ -70,7 +70,7 @@ class ModelConfig(BaseModel, extra="forbid"):
config = ModelConfig.from_path("model_config.yaml")

# Load from command line arguments
config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature=0.7}")
config = ModelConfig.from_args("model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt='You are a helpful assistant.',generation_parameters={temperature:0.7}")

# Direct instantiation
config = ModelConfig(
Expand All @@ -81,7 +81,9 @@ class ModelConfig(BaseModel, extra="forbid"):
```
"""

model_name: str
generation_parameters: GenerationParameters = GenerationParameters()
chat_template_parameters: ChatTemplateParameters = ChatTemplateParameters()
system_prompt: str | None = None

@classmethod
@@ -131,20 +133,31 @@ def _parse_args(args: str) -> dict:
"""
# Looking for generation_parameters in the model_args
generation_parameters_dict = None
pattern = re.compile(r"(\w+)=(\{.*\}|[^,]+)")
chat_template_parameters_dict = None
pattern = re.compile(r"(\w+)\s*=\s*(\{[^{}]*\}|[^,]+?)(?=,|$)")
matches = pattern.findall(args)
for key, value in matches:
key = key.strip()
if key == "generation_parameters":
gen_params = re.sub(r"(\w+):", r'"\1":', value)
generation_parameters_dict = json.loads(gen_params)
if key == "chat_template_parameters":
# Chat template parameters have strings as values that also need to be quoted
Review comment (Member): would be relevant to have tests for parsing edge cases

chat_template_params = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', value)
chat_template_parameters_dict = json.loads(chat_template_params)

args = re.sub(r"generation_parameters=\{.*?\},?", "", args).strip(",")
model_config = {k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")}
args = re.sub(r"chat_template_parameters=\{.*?\},?", "", args).strip(",")
model_config = (
{k.split("=")[0]: k.split("=")[1] if "=" in k else True for k in args.split(",")} if args else {}
)

if generation_parameters_dict is not None:
model_config["generation_parameters"] = generation_parameters_dict

if chat_template_parameters_dict is not None:
model_config["chat_template_parameters"] = chat_template_parameters_dict

return model_config


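As a standalone illustration of the quoting step above (the regexes are copied from this hunk; nothing here imports lighteval), bare string values such as reasoning_effort:low are wrapped in quotes before json.loads, whereas generation_parameters only needs its keys quoted:

```python
# Standalone sketch of the value-quoting logic in _parse_args (regexes copied from this hunk).
import json
import re

# generation_parameters: only keys need quoting, the values are already valid JSON numbers.
gen = re.sub(r"(\w+):", r'"\1":', "{temperature:0.7}")
print(json.loads(gen))  # {'temperature': 0.7}

# chat_template_parameters: bare string values are quoted too before parsing.
chat = re.sub(r"(\w+)\s*:\s*([A-Za-z_][\w.-]*)\s*(?=[,}])", r'"\1":"\2"', "{reasoning_effort:low}")
print(json.loads(chat))  # {'reasoning_effort': 'low'}
```
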
1 change: 0 additions & 1 deletion src/lighteval/models/vllm/vllm_model.py
@@ -140,7 +140,6 @@ class VLLMModelConfig(ModelConfig):
```
"""

model_name: str
revision: str = "main" # revision of the model
dtype: str = "bfloat16"
tensor_parallel_size: PositiveInt = 1 # how many GPUs to use for tensor parallelism
5 changes: 4 additions & 1 deletion src/lighteval/pipeline.py
@@ -177,8 +177,11 @@ def __init__(
self.model = self._init_model(model_config, model)

generation_parameters = model_config.generation_parameters.model_dump() if model_config else {}
chat_template_parameters = model_config.chat_template_parameters.model_dump() if model_config else {}

self.evaluation_tracker.general_config_logger.log_model_info(generation_parameters, self.model.model_info)
self.evaluation_tracker.general_config_logger.log_model_info(
generation_parameters, self.model.model_info, chat_template_parameters
)

self._init_random_seeds()
self._init_tasks_and_requests(tasks=tasks)
11 changes: 10 additions & 1 deletion src/lighteval/tasks/prompt_manager.py
@@ -28,6 +28,7 @@
from itertools import cycle
from typing import TYPE_CHECKING

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.requests import Doc
from lighteval.utils.utils import as_list

@@ -40,10 +41,17 @@


class PromptManager:
def __init__(self, use_chat_template: bool = False, tokenizer=None, system_prompt: str | None = None):
def __init__(
self,
use_chat_template: bool = False,
tokenizer=None,
system_prompt: str | None = None,
chat_template_parameters: ChatTemplateParameters | None = None,
):
self.use_chat_template = use_chat_template
self.tokenizer = tokenizer
self.system_prompt = system_prompt # System prompt to be used in chat templates
self.chat_template_parameters = chat_template_parameters if chat_template_parameters else {}

def prepare_prompt(self, doc: Doc) -> str:
"""Prepare a prompt from a document, either using chat template or plain text format."""
@@ -123,6 +131,7 @@ def _prepare_chat_template(self, doc: Doc, tokenize: bool = True) -> str:
messages,
tokenize=False,
add_generation_prompt=True,
**self.chat_template_parameters.to_transformers_dict(),
)

else: # for apis
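End to end, the parameters are spread into the tokenizer's chat template call as extra template variables. A hedged sketch, assuming a Hugging Face tokenizer whose Jinja chat template actually reads a reasoning_effort variable (the model name below is only an example, not part of this PR):

```python
# Sketch: extra kwargs passed to apply_chat_template are exposed to the Jinja template;
# they only have an effect if the model's chat template uses them.
from transformers import AutoTokenizer

from lighteval.models.model_input import ChatTemplateParameters

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")  # example model, assumed to consume reasoning_effort
params = ChatTemplateParameters(reasoning_effort="low")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2 + 2?"}],
    tokenize=False,
    add_generation_prompt=True,
    **params.to_transformers_dict(),
)
print(prompt)
```
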
17 changes: 17 additions & 0 deletions tests/test_prompt_manager_class.py
@@ -24,6 +24,7 @@

import pytest

from lighteval.models.model_input import ChatTemplateParameters
from lighteval.tasks.prompt_manager import PromptManager
from lighteval.tasks.requests import Doc

@@ -47,6 +48,22 @@ def test_init_with_chat_template(self):
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt

def test_init_with_chat_template_and_chat_template_parameters(self):
"""Test PromptManager initialization with chat template enabled and chat template parameters."""
tokenizer = Mock()
system_prompt = "You are a helpful assistant."
pm = PromptManager(
use_chat_template=True,
tokenizer=tokenizer,
system_prompt=system_prompt,
chat_template_parameters=ChatTemplateParameters(reasoning_effort="medium"),
)
assert pm.use_chat_template is True
assert pm.tokenizer == tokenizer
assert pm.system_prompt == system_prompt
assert pm.chat_template_parameters is not None
assert pm.chat_template_parameters.reasoning_effort == "medium"

def test_prepare_prompt_plain_text_basic(self):
"""Test prepare_prompt with plain text format and basic document."""
pm = PromptManager()
84 changes: 84 additions & 0 deletions tests/utils/test_model_config.py
@@ -0,0 +1,84 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import unittest

from lighteval.models.model_input import ChatTemplateParameters, GenerationParameters
from lighteval.models.utils import ModelConfig


class TestModelConfig(unittest.TestCase):
def test_model_config_init(self):
config = ModelConfig(
model_name="meta-llama/Llama-3.1-8B-Instruct",
generation_parameters=GenerationParameters(temperature=0.7),
system_prompt="You are a helpful assistant.",
chat_template_parameters=ChatTemplateParameters(reasoning_effort="low"),
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, "You are a helpful assistant.")
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_init_command_line(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,system_prompt="You are a helpful assistant.",generation_parameters={temperature:0.7},chat_template_parameters={reasoning_effort:low}'
)

self.assertEqual(config.model_name, "meta-llama/Llama-3.1-8B-Instruct")
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.system_prompt, '"You are a helpful assistant."') # is this what we want?
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")

def test_model_config_generation_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

def test_model_config_generation_parameters_parse_multiple_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={temperature:0.7,top_k:42}"
)
self.assertEqual(config.generation_parameters.temperature, 0.7)
self.assertEqual(config.generation_parameters.top_k, 42)

@unittest.skip("This is not working at this time")
def test_model_config_generation_parameters_parse_string(self):
config = ModelConfig.from_args(
'model_name=meta-llama/Llama-3.1-8B-Instruct,generation_parameters={response_format:{"type":"json_object"}}'
)
self.assertEqual(config.generation_parameters.temperature, 0.7)

@unittest.skip("This is not working at this time")
def test_model_config_chat_template_parameters_parse_single_int(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={temperature:0.7}"
)
self.assertEqual(config.chat_template_parameters.temperature, 0.7)

def test_model_config_chat_template_parameters_parse_string(self):
config = ModelConfig.from_args(
"model_name=meta-llama/Llama-3.1-8B-Instruct,chat_template_parameters={reasoning_effort:low}"
)
self.assertEqual(config.chat_template_parameters.reasoning_effort, "low")