auto_round/eval/evaluation.py (5 changes: 4 additions & 1 deletion)
@@ -26,6 +26,7 @@ def simple_evaluate_user_model(
user_model,
tokenizer,
batch_size: Optional[int] = 1,
limit: Optional[Union[int, float]] = None,
max_batch_size: Optional[int] = 64,
eval_model_dtype="auto",
add_bos_token: bool = False,
@@ -40,14 +41,15 @@ def simple_evaluate_user_model(
add_bos_token=add_bos_token,
)
return lm_simple_evaluate(
model=hflm, model_args=None, batch_size=batch_size, max_batch_size=max_batch_size, **kwargs
model=hflm, model_args=None, batch_size=batch_size, max_batch_size=max_batch_size, limit=limit, **kwargs
)


def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
batch_size: Optional[int] = None,
limit: Optional[Union[int, float]] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
**kwargs
@@ -61,6 +63,7 @@ def simple_evaluate(
model=model,
model_args=model_args,
batch_size=batch_size,
limit=limit,
max_batch_size=max_batch_size,
device=device,
**kwargs
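For orientation, here is a minimal sketch of how the `limit` argument threaded through `simple_evaluate_user_model` above might be exercised. The import path matches the file shown here, while the checkpoint name, task choice, and batch size are illustrative assumptions rather than anything taken from this PR.

# Hedged sketch: cap the number of evaluated examples per task via `limit`.
# The checkpoint and task below are placeholders for a quick smoke test.
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round.eval.evaluation import simple_evaluate_user_model

model_name = "facebook/opt-125m"  # assumed small model, not part of this PR
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# An integer caps each task at that many examples; a float in (0, 1)
# evaluates that fraction of each task instead (e.g. limit=0.1 for ~10%).
result = simple_evaluate_user_model(
    model,
    tokenizer,
    tasks="lambada_openai",
    batch_size=8,
    limit=10,
)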
auto_round/script/llm.py (82 changes: 56 additions & 26 deletions)
@@ -80,8 +80,6 @@ def __init__(self, *args, **kwargs):

self.add_argument("--disable_act_dynamic", action="store_true", help="activation static quantization")

self.add_argument("--eval_bs", default=None, type=int, help="batch size in evaluation")

self.add_argument(
"--device_map",
"--device",
@@ -125,26 +123,10 @@ def __init__(self, *args, **kwargs):
help="scale data type to use for quantization",
)

self.add_argument(
"--tasks",
"--task",
nargs="?",
const="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1,"
"openbookqa,boolq,arc_easy,arc_challenge",
default=None,
help="lm-eval tasks",
)

self.add_argument(
"--output_dir", default="./tmp_autoround", type=str, help="the directory to save quantized model"
)

self.add_argument(
"--disable_eval", action="store_true", help="whether to disable lm-eval evaluation after tuning"
)

self.add_argument("--eval_task_by_task", action="store_true", help="whether to eval task by task.")

self.add_argument("--disable_amp", action="store_true", help="disable amp")

self.add_argument(
@@ -220,10 +202,6 @@ def __init__(self, *args, **kwargs):
"--disable_deterministic_algorithms", action="store_true", help="disable torch deterministic algorithms."
)

self.add_argument(
"--eval_model_dtype", default=None, type=str, help="the torch_dytpe to load the model for evaluation."
)

self.add_argument(
"--disable_opt_rtn",
action="store_true",
@@ -255,6 +233,38 @@ def __init__(self, *args, **kwargs):
help="the template for building training dataset. It can be a custom one.",
)

## ======================= eval =======================
self.add_argument(
"--disable_eval", action="store_true", help="whether to disable lm-eval evaluation after tuning"
)

self.add_argument(
"--tasks",
"--task",
nargs="?",
const="lambada_openai,hellaswag,winogrande,piqa,mmlu,wikitext,truthfulqa_mc1,"
"openbookqa,boolq,arc_easy,arc_challenge",
default=None,
help="lm-eval tasks",
)

self.add_argument("--eval_bs", default=None, type=int, help="batch size in evaluation")

self.add_argument(
"--limit",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)

self.add_argument("--eval_task_by_task", action="store_true", help="whether to eval task by task.")

self.add_argument(
"--eval_model_dtype", default=None, type=str, help="the torch_dytpe to load the model for evaluation."
)


class EvalArgumentParser(argparse.ArgumentParser):

@@ -292,6 +302,14 @@ def __init__(self, *args, **kwargs):
self.add_argument(
"--eval_model_dtype", default=None, type=str, help="the torch_dytpe to load the model for evaluation."
)
self.add_argument(
"--limit",
type=float,
default=None,
metavar="N|0<N<1",
help="Limit the number of examples per task. "
"If <1, limit is a percentage of the total number of examples.",
)


def setup_parser():
@@ -683,6 +701,7 @@ def tune(args):
tokenizer=tokenizer,
device=device_str,
tasks=args.tasks,
limit=args.limit,
batch_size=args.eval_bs,
eval_model_dtype=eval_model_dtype,
)
@@ -701,6 +720,7 @@ def tune(args):
tokenizer,
tasks=tasks,
batch_size=args.eval_bs,
limit=args.limit,
device=device_str,
eval_model_dtype=eval_model_dtype,
add_bos_token=add_bos_token,
@@ -714,6 +734,7 @@ def tune(args):
device=device_str,
tasks=args.tasks,
batch_size=args.eval_bs,
limit=args.limit,
eval_model_dtype=eval_model_dtype,
)
else:
@@ -726,7 +747,12 @@ def tune(args):
if "llama" in args.model.lower():
model_args += ",add_bos_token=True"
res = simple_evaluate(
model="hf", model_args=model_args, tasks=tasks, device=device_str, batch_size=args.eval_bs
model="hf",
model_args=model_args,
tasks=tasks,
device=device_str,
batch_size=args.eval_bs,
limit=args.limit,
)
print(make_table(res))
print("evaluation running time=%ds" % (time.time() - st))
@@ -788,7 +814,9 @@ def eval(args):
)
model.eval()
st = time.time()
res = simple_evaluate_user_model(model, tokenizer, tasks=tasks, batch_size=batch_size, device=device_str)
res = simple_evaluate_user_model(
model, tokenizer, tasks=tasks, batch_size=batch_size, device=device_str, limit=args.limit
)
print(make_table(res))
print("evaluation running time=%ds" % (time.time() - st))
else:
@@ -802,6 +830,7 @@ def eval(args):
tasks=tasks,
device=device_str,
batch_size=batch_size,
limit=args.limit,
)
from lm_eval.utils import make_table # pylint: disable=E0401

@@ -815,6 +844,7 @@ def eval_task_by_task(
tasks=None,
tokenizer=None,
batch_size=None,
limit=None,
max_batch_size=64,
trust_remote_code=True,
eval_model_dtype=None,
@@ -887,7 +917,7 @@ def eval_task_by_task(
while retry_times:
try:
res = lm_simple_evaluate(
model=hflm, model_args=None, device=device_str, tasks=task, batch_size=batch_size
model=hflm, model_args=None, device=device_str, tasks=task, batch_size=batch_size, limit=limit
)
break
except Exception as e:
@@ -899,7 +929,7 @@ def eval_task_by_task(
hflm.batch_sizes[k] = max(v // 2, 1)
logger.warning(f"Out of memory, reset batch_size to {hflm.batch_sizes} and re-try.")
res = lm_simple_evaluate(
model=hflm, model_args=None, device=device_str, tasks=task, batch_size=1
model=hflm, model_args=None, device=device_str, tasks=task, batch_size=1, limit=limit
)
hflm.batch_sizes = ori_batch_sizes
except Exception as e:
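Taken together, the new `--limit` flag and the `limit` parameter follow the convention stated in the help text: an integer caps the number of examples per task, while a value strictly between 0 and 1 is read as a fraction of each task's examples. The snippet below restates that rule for illustration only; the helper is hypothetical and is not a copy of lm-eval's or auto_round's internals.

# Illustrative sketch of the documented --limit semantics; resolve_limit
# is a hypothetical helper, not a function from auto_round or lm-eval.
def resolve_limit(limit, total_examples):
    if limit is None:
        return total_examples  # no cap: evaluate every example
    if 0 < limit < 1:
        return max(1, int(limit * total_examples))  # fractional cap, e.g. 0.1 -> ~10%
    return min(int(limit), total_examples)  # absolute cap

print(resolve_limit(None, 5000))  # 5000
print(resolve_limit(0.1, 5000))   # 500
print(resolve_limit(200, 5000))   # 200

On the command line this presumably pairs with the existing flags, for example something like `--tasks lambada_openai --eval_bs 8 --limit 0.1` for a quick partial run before committing to a full evaluation.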