28 changes: 28 additions & 0 deletions auto_round/__main__.py
@@ -367,6 +367,24 @@ def __init__(self, *args, **kwargs):
"Options: 'float16', 'bfloat16', 'float32'. "
"Should match your hardware capabilities for best performance.",
)
eval_args.add_argument(
"--task_configs",
type=str,
default=None,
help=(
"Optional per-task configuration in JSON or simplified format. "
"Example JSON: "
'\'{"gsm8k_llama": {"apply_chat_template": true, "fewshot_as_multiturn": true}, '
' "hellaswag": {"num_fewshot": 10}}\' '
"You can also provide a JSON file path like 'task_configs.json'."
),
)
eval_args.add_argument(
"--disable_thinking",
action="store_true",
help=("whether to disable thinking mode of chat_template."),
)
eval_args.add_argument("--max_length", default=None, type=int, help="Random seed for reproducibility.")

## ======================= MLLM =======================
mllm_args = self.add_argument_group("Multimodal Large Language Model (MLLM) arguments")
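
For reference, a minimal sketch of how a --task_configs payload could be produced in either accepted form. The task names, option keys, and the task_configs.json path below come from the help text above; json is the standard-library module:

import json

# Per-task overrides keyed by lm-eval task name; the keys mirror the
# lm_simple_evaluate arguments consumed later in eval_cli.py.
task_configs = {
    "gsm8k_llama": {"apply_chat_template": True, "fewshot_as_multiturn": True},
    "hellaswag": {"num_fewshot": 10},
}

# Inline form: pass the serialized dict directly as the flag value.
inline_arg = json.dumps(task_configs)

# File form: write it out and pass the path instead.
with open("task_configs.json", "w") as f:
    json.dump(task_configs, f, indent=2)

# Downstream lookups fall back to defaults for unlisted tasks:
cfg = task_configs.get("hellaswag", {})
num_fewshot = cfg.get("num_fewshot")  # 10
apply_chat_template = cfg.get("apply_chat_template", False)  # False

The parser added in eval_cli.py accepts either form: a value that resolves to an existing file is loaded with json.load, anything else is parsed with json.loads.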
@@ -735,6 +753,9 @@ def tune(args):
limit=args.limit,
batch_size=args.eval_bs,
eval_model_dtype=eval_model_dtype,
task_configs=args.task_configs,
disable_thinking=args.disable_thinking,
max_length=args.max_length,
)
else:
if args.eval_bs is None or args.eval_bs == "auto":
@@ -763,11 +784,15 @@ def tune(args):
eval_task_by_task(
eval_folder,
device=device_str,
tokenizer=tokenizer,
tasks=args.tasks,
batch_size=args.eval_bs,
limit=args.limit,
eval_model_dtype=eval_model_dtype,
mllm=autoround.mllm, # pylint: disable=E1101
task_configs=args.task_configs,
disable_thinking=args.disable_thinking,
max_length=args.max_length,
)
else:
from auto_round.eval.evaluation import simple_evaluate
@@ -821,6 +846,9 @@ def run_eval():
batch_size=args.eval_bs,
trust_remote_code=not args.disable_trust_remote_code,
eval_model_dtype=args.eval_model_dtype,
task_configs=args.task_configs,
disable_thinking=args.disable_thinking,
max_length=args.max_length,
)
else:
eval(args)
95 changes: 93 additions & 2 deletions auto_round/eval/eval_cli.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import time

@@ -101,6 +102,24 @@ def __init__(self, *args, **kwargs):
choices=["hf", "vllm"],
help="Backend to use for model evaluation. Use hf backend for evaluation by default.",
)
self.add_argument(
"--task_configs",
type=str,
default=None,
help=(
"Optional per-task configuration in JSON or simplified format. "
"Example JSON: "
'\'{"gsm8k_llama": {"apply_chat_template": true, "fewshot_as_multiturn": true}, '
' "hellaswag": {"num_fewshot": 10}}\' '
"You can also provide a JSON file path like 'task_configs.json'."
),
)
self.add_argument(
"--disable_thinking",
action="store_true",
help=("whether to disable thinking mode of chat_template."),
)
self.add_argument("--max_length", default=None, type=int, help="max generation length for eval")

# vllm related arguments
vllm_args = self.add_argument_group("vllm backend arguments")
@@ -221,7 +240,34 @@ def eval_task_by_task(
eval_model_dtype=None,
retry_times=3,
mllm=False,
task_configs=None, # e.g. {"gsm8k": {"apply_chat_template": True, "fewshot_as_multiturn": True}}
disable_thinking=False,
max_length=None,  # None keeps the model's original maximum length
):
"""
Evaluate each LM-eval task sequentially, with optional per-task overrides.

Args:
model (str | nn.Module): Model path or loaded model.
device (str): Device id (e.g. "0" or "cuda:0").
tasks (list[str] | str): Tasks to run, separated by comma.
tokenizer: HuggingFace tokenizer.
batch_size: Eval batch size (default: "auto:8").
limit: Number of samples or fraction per task.
task_configs (dict): Optional task-specific settings like fewshot/chat.
"""
if isinstance(task_configs, str):
if os.path.isfile(task_configs):
with open(task_configs, "r") as f:
task_configs = json.load(f)
else:
try:
task_configs = json.loads(task_configs)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid --task_configs format: {e}")
elif task_configs is None:
task_configs = {}

set_cuda_visible_devices(device)
device_str, parallelism = get_device_and_parallelism(device)

@@ -237,6 +283,10 @@

if batch_size is None:
batch_size = "auto:8"

# -------------------------------
# Load model (support gguf)
# -------------------------------
is_gguf_file = False
if not isinstance(model, str):
parallelism = False
@@ -265,6 +315,19 @@
)
model.eval()
parallelism = False

# -------------------------------
# Build LM-eval model wrapper
# -------------------------------
if disable_thinking:  # align with fp-quant behavior
from functools import partial

tokenizer.apply_chat_template = partial(tokenizer.apply_chat_template, enable_thinking=False)
# Forward max_length to the lm-eval wrapper only when explicitly set
init_kwargs = {}
if max_length is not None:
init_kwargs["max_length"] = max_length

if mllm:
if batch_size is None or batch_size == "auto":
logger.warning("hf-multimodal models does not support auto currently, reset eval_bs to 16")
@@ -278,6 +341,7 @@
parallelize=parallelism,
trust_remote_code=trust_remote_code,
dtype=eval_model_dtype,
**init_kwargs,
)
else:
hflm = HFLM(
@@ -289,6 +353,7 @@
parallelize=parallelism,
trust_remote_code=trust_remote_code,
dtype=eval_model_dtype,
**init_kwargs,
)

if isinstance(tasks, str):
@@ -302,10 +367,28 @@

st = time.time()
for task in tasks:
task_cfg = task_configs.get(task, {})
num_fewshot = task_cfg.get("num_fewshot")
apply_chat_template = task_cfg.get("apply_chat_template", False)
task_batch_size = task_cfg.get("batch_size", batch_size)  # local name, so a per-task override does not leak into later tasks
fewshot_as_multiturn = task_cfg.get("fewshot_as_multiturn", False)
logger.info(f"=== Running task: {task} ===")
logger.info(
f"Task config: fewshot={num_fewshot}, apply_chat_template={apply_chat_template}, "
f"fewshot_as_multiturn={fewshot_as_multiturn}, batch_size={task_batch_size}"
)
while retry_times:
try:
res = lm_simple_evaluate(
model=hflm,
model_args=None,
device=device_str,
tasks=task,
batch_size=task_batch_size,
limit=limit,
num_fewshot=num_fewshot,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
)
break
except Exception as e:
@@ -317,7 +400,15 @@
hflm.batch_sizes[k] = max(v // 2, 1)
logger.warning(f"Out of memory, reset batch_size to {hflm.batch_sizes} and re-try.")
res = lm_simple_evaluate(
model=hflm,
model_args=None,
device=device_str,
tasks=task,
batch_size=1,
limit=limit,
num_fewshot=num_fewshot,
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
)
hflm.batch_sizes = ori_batch_sizes
except Exception as e:
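
The retry logic above halves each cached batch size on a CUDA OOM before re-running the task at batch_size=1. A condensed, hypothetical sketch of that fallback pattern (evaluate and batch_sizes are stand-ins here, not the real lm-eval API):

def run_with_oom_fallback(evaluate, batch_sizes, retries=3):
    # Try at the current sizes; on CUDA OOM, halve every cached batch
    # size (never below 1) and retry. Restore the original sizes
    # afterwards so the next task starts from a clean slate.
    original = dict(batch_sizes)
    try:
        for _ in range(retries):
            try:
                return evaluate(batch_sizes)
            except RuntimeError as e:
                if "CUDA out of memory" not in str(e):
                    raise
                for k, v in batch_sizes.items():
                    batch_sizes[k] = max(v // 2, 1)
        raise RuntimeError("still out of memory after retries")
    finally:
        batch_sizes.clear()
        batch_sizes.update(original)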