Commit e68a0bf

Moved custom_yaml.py in utils and other appropriate changes
1 parent c8fa8cd commit e68a0bf

File tree

2 files changed: +25, -7 lines


QEfficient/cloud/export.py

Lines changed: 17 additions & 7 deletions

@@ -11,23 +11,23 @@
 
 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir
+from QEfficient.utils.custom_yaml import generate_custom_io
 from QEfficient.utils.logging_utils import logger
 
-from .custom_yaml import generate_custom_io
-
 # Specifically for Docker images.
 ROOT_DIR = os.path.dirname(os.path.abspath(""))
 
 
-def get_onnx_model_path(
+def get_onnx_path_and_setup_customIO(
     model_name: str,
     cache_dir: Optional[str] = None,
     hf_token: Optional[str] = None,
     full_batch_size: Optional[int] = None,
     local_model_dir: Optional[str] = None,
+    mxint8_kv_cache: Optional[int] = False,
 ):
     """
-    exports the model to onnx if pre-exported file is not found and returns onnx_model_path
+    exports the model to onnx if pre-exported file is not found and returns onnx_model_path and generates cutom_io file.
 
     ``Mandatory`` Args:
         :model_name (str): Hugging Face Model Card name, Example: ``gpt2``.
@@ -47,9 +47,11 @@ def get_onnx_model_path
         full_batch_size=full_batch_size,
         local_model_dir=local_model_dir,
     )
-    generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=False)
     onnx_model_path = qeff_model.export()
     logger.info(f"Generated onnx_path: {onnx_model_path}")
+
+    # Generating Custom IO for the compile.
+    generate_custom_io(qeff_model, mxint8_kv_cache=mxint8_kv_cache)
     return onnx_model_path
 
 
@@ -59,6 +61,7 @@ def main(
     hf_token: Optional[str] = None,
     local_model_dir: Optional[str] = None,
     full_batch_size: Optional[int] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ) -> None:
     """
     Helper function used by export CLI app for exporting to ONNX Model.
@@ -71,19 +74,20 @@ def main(
         :hf_token (str): HuggingFace login token to access private repos. ``Defaults to None.``
         :local_model_dir (str): Path to custom model weights and config files. ``Defaults to None.``
         :full_batch_size (int): Set full batch size to enable continuous batching mode. ``Defaults to None.``
-
+        :mxint8_kv_cache (bool): Whether to export int8 model or not. ``Defaults to False.``
     .. code-block:: bash
 
         python -m QEfficient.cloud.export OPTIONS
 
     """
     cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    get_onnx_model_path(
+    get_onnx_path_and_setup_customIO(
        model_name=model_name,
        cache_dir=cache_dir,
        hf_token=hf_token,
        full_batch_size=full_batch_size,
        local_model_dir=local_model_dir,
+       mxint8_kv_cache=mxint8_kv_cache,
    )
 
 
@@ -109,5 +113,11 @@ def main(
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
     )
+    parser.add_argument(
+        "--mxint8_kv_cache",
+        "--mxint8-kv-cache",
+        required=False,
+        help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
+    )
     args = parser.parse_args()
     main(**args.__dict__)
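For context, a minimal sketch of how the renamed helper and the new flag might be invoked after this change. The model card "gpt2" comes from the docstring example; the --model-name CLI option is an assumption about the existing parser, not part of this diff.

# Python API sketch: export to ONNX and write the custom IO config in one call.
from QEfficient.cloud.export import get_onnx_path_and_setup_customIO

onnx_path = get_onnx_path_and_setup_customIO(
    model_name="gpt2",       # example model card from the docstring
    mxint8_kv_cache=True,    # also emits the MXINT8 custom IO YAML used for compile
)

# CLI sketch: as added above, --mxint8_kv_cache / --mxint8-kv-cache has no
# store_true action, so it expects a value on the command line, e.g.:
#   python -m QEfficient.cloud.export --model-name gpt2 --mxint8-kv-cache True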

QEfficient/cloud/custom_yaml.py renamed to QEfficient/utils/custom_yaml.py

Lines changed: 8 additions & 0 deletions

@@ -1,3 +1,10 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
 import warnings
 from pathlib import Path
 
@@ -94,6 +101,7 @@ def generate(self) -> dict:
 
         self.dump(custom_io_vision, f"{self.dtype_suffix}_vision")
         self.dump(custom_io_lang, f"{self.dtype_suffix}_lang")
+        warnings.warn(f"Unsupported model class via CLI: {type(self.model).__name__}", UserWarning)
         return {**custom_io_vision, **custom_io_lang}
 
 
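With the module now under QEfficient.utils, a minimal sketch of calling it directly from its new location; the loader call mirrors the usage in export.py and is an assumption for illustration, not part of this diff.

from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils.custom_yaml import generate_custom_io  # new import path after the move

# Assumed loader usage, mirroring export.py: load the model, then dump its custom IO YAML.
qeff_model = QEFFCommonLoader.from_pretrained("gpt2")
generate_custom_io(qeff_model, mxint8_kv_cache=True)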
