From 6530d733e1cc4858cd7a595f128bb2f3eb1d242d Mon Sep 17 00:00:00 2001 From: isamu-isozaki Date: Fri, 28 Jun 2024 22:59:05 -0400 Subject: [PATCH 1/3] Exllamav2_filter Fix comment Fixed precommit issues Removed text Basic draft done Passed local test Fixed tests+precommit Revert change for pyairports Fixed precommit Wrap up Remove | for union Attempt changing to List Fixed for 3.8 Adding exllamav2 to optional dependency Fixed model Changed to fork Fix format Changed order Skip exllamav2 tests Attempt fixing coverage Attempt fix coverage Remove flash-attn requirement Fixed fixture tests Removed lora Passed coverage Added back transformers install Fixed per review Made coverage 100% --- outlines/generate/fsm.py | 13 +- outlines/generate/regex.py | 18 +- outlines/generate/text.py | 11 +- outlines/models/exllamav2.py | 410 ++++++++++++------- pyproject.toml | 1 + tests/generate/conftest.py | 8 +- tests/generate/test_generate.py | 10 + tests/generate/test_integration_exllamav2.py | 363 ++++++++++++++++ 8 files changed, 641 insertions(+), 193 deletions(-) create mode 100644 tests/generate/test_integration_exllamav2.py diff --git a/outlines/generate/fsm.py b/outlines/generate/fsm.py index a9338836a..1950812d2 100644 --- a/outlines/generate/fsm.py +++ b/outlines/generate/fsm.py @@ -4,11 +4,10 @@ from outlines.fsm.guide import RegexGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, TransformersVision +from outlines.models import TransformersVision from outlines.samplers import Sampler, multinomial @@ -30,13 +29,3 @@ def fsm_vision(model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() guide = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) logits_processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide) return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) - - -@fsm.register(ExLlamaV2Model) -def fsm_exllamav2( - model, fsm: interegular.fsm.FSM, sampler: Sampler = multinomial() -) -> SequenceGenerator: - fsm = RegexGuide.from_interegular_fsm(fsm, model.tokenizer) - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - return generator diff --git a/outlines/generate/regex.py b/outlines/generate/regex.py index 815a8b1b9..673880e49 100644 --- a/outlines/generate/regex.py +++ b/outlines/generate/regex.py @@ -1,12 +1,10 @@ from functools import singledispatch -from outlines.fsm.guide import RegexGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -49,20 +47,6 @@ def regex_vision( return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) -@regex.register(ExLlamaV2Model) -def regex_exllamav2( - model, - regex_str: str, - sampler: Sampler = multinomial(), -) -> SequenceGenerator: - fsm = RegexGuide(regex_str, model.tokenizer) - - device = model.device - generator = SequenceGenerator(fsm, model, sampler, device) - - return generator - - @regex.register(OpenAI) def regex_openai( model: OpenAI, diff --git a/outlines/generate/text.py b/outlines/generate/text.py index 3fe3dc553..32530d0c4 100644 --- a/outlines/generate/text.py +++ b/outlines/generate/text.py @@ -1,12 +1,10 @@ from functools import singledispatch -from outlines.fsm.guide import StopAtEOSGuide from outlines.generate.api import ( - SequenceGenerator, SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, OpenAI, TransformersVision +from outlines.models import OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -36,13 +34,6 @@ def text(model, sampler: Sampler = multinomial()) -> SequenceGeneratorAdapter: return SequenceGeneratorAdapter(model, None, sampler) -@text.register(ExLlamaV2Model) -def text_exllamav2(model, sampler: Sampler = multinomial()) -> SequenceGenerator: - fsm = StopAtEOSGuide(model.tokenizer) - device = model.device - return SequenceGenerator(fsm, model, sampler, device) - - @text.register(TransformersVision) def text_vision(model, sampler: Sampler = multinomial()): return VisionSequenceGeneratorAdapter(model, None, sampler) diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py index 0ec6ef033..f06b7e46e 100644 --- a/outlines/models/exllamav2.py +++ b/outlines/models/exllamav2.py @@ -1,12 +1,21 @@ -import os -from typing import TYPE_CHECKING, Optional +import dataclasses +from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union + +from typing_extensions import Unpack + +from outlines.generate.api import GenerationParameters, SamplingParameters if TYPE_CHECKING: - from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Lora - from transformers import PreTrainedTokenizer - import torch + from exllamav2 import ExLlamaV2Tokenizer + from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler + -from .transformers import TransformerTokenizer +class ExllamaV2Params(TypedDict, total=False): + max_tokens: int + stop_conditions: Optional[List[Union[int, str]]] + seed: Optional[int] + gen_settings: "ExLlamaV2Sampler.Settings" + max_new_tokens: List[int] class ExLlamaV2Model: @@ -14,108 +23,218 @@ class ExLlamaV2Model: def __init__( self, - model: "ExLlamaV2", - tokenizer: "PreTrainedTokenizer", - device, - cache: "ExLlamaV2Cache", - lora: Optional["ExLlamaV2Lora"] = None, + generator: "ExLlamaV2DynamicGenerator", + tokenizer: "ExLlamaV2Tokenizer", + max_seq_len: int, ): - self.device = device - self.model = model - self.tokenizer = TransformerTokenizer(tokenizer) - self.cache = cache - self.past_seq = None - self.lora = lora - - def forward(self, input_ids: "torch.LongTensor", *_): - """Compute a forward pass through the exl2 model.""" - import torch - - # Caching with past_seq - reset = True - seq_tensor = input_ids[0] - - if self.past_seq is not None: - min_length = min(self.past_seq.shape[0], seq_tensor.shape[0]) - indices = torch.nonzero( - ~torch.eq(self.past_seq[:min_length], seq_tensor[:min_length]) - ) - if len(indices) > 0: - longest_prefix = indices[0].item() - else: - longest_prefix = min_length - - if longest_prefix > 0: - reset = False - self.cache.current_seq_len = longest_prefix - if seq_tensor.shape[0] - longest_prefix > 1: - self.model.forward( - seq_tensor[longest_prefix:-1].view(1, -1), - self.cache, - preprocess_only=True, - loras=[self.lora], - ) - elif seq_tensor.shape[0] == longest_prefix: - self.cache.current_seq_len -= 1 - - if reset: - self.cache.current_seq_len = 0 - if seq_tensor.shape[0] > 1: - self.model.forward( - seq_tensor[:-1].view(1, -1), - self.cache, - preprocess_only=True, - loras=[self.lora], + self.generator = generator + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + + def prepare_generation_parameters( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + sampling_parameters: SamplingParameters, + structure_logits_processor, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Tuple[ExllamaV2Params, Union[str, List[str]]]: + """Prepare the generation parameters. + + `exllamav2` uses different default values + + """ + from exllamav2.generator import ExLlamaV2Sampler + + if isinstance(prompts, str): + prompts = [prompts] + max_tokens, stop_at, seed = dataclasses.astuple(generation_parameters) + + if max_tokens is None: + max_tokens = [] + for prompt in prompts: + ids = self.generator.tokenizer.encode( + prompt, encode_special_tokens=True ) + prompt_tokens = ids.shape[-1] + max_tokens.append(self.max_seq_len - prompt_tokens) + exllamav2_params["max_new_tokens"] = max_tokens + else: + exllamav2_params["max_new_tokens"] = [ + max_tokens for _ in range(len(prompts)) + ] - self.past_seq = seq_tensor + stop_conditions = [self.generator.tokenizer.eos_token_id] + if isinstance(generation_parameters.stop_at, str): + stop_conditions.append(generation_parameters.stop_at) + elif isinstance(generation_parameters.stop_at, list): + for stop_at in generation_parameters.stop_at: + stop_conditions.append(stop_at) + exllamav2_params["stop_conditions"] = stop_conditions + exllamav2_params["seed"] = seed - return self.model.forward( - seq_tensor[-1:].view(1, -1), self.cache, loras=[self.lora] - ) + gen_settings = ExLlamaV2Sampler.Settings() + if sampling_parameters.temperature is not None: + gen_settings.temperature = sampling_parameters.temperature + if sampling_parameters.top_p is not None: + gen_settings.top_p = sampling_parameters.top_p + if sampling_parameters.top_k is not None: + gen_settings.top_k = sampling_parameters.top_k + gen_settings.logits_processor = structure_logits_processor + exllamav2_params["gen_settings"] = gen_settings + if sampling_parameters.num_samples > 1: + prompts = prompts * sampling_parameters.num_samples + exllamav2_params["max_new_tokens"] = ( + exllamav2_params["max_new_tokens"] * sampling_parameters.num_samples + ) - def __call__(self, input_ids: "torch.LongTensor", *_) -> "torch.FloatTensor": - logits = self.forward(input_ids) - next_token_logits = logits[..., -1, :] + if len(prompts) == 1: + prompts = prompts[0] - return next_token_logits, None + return exllamav2_params, prompts - def update_lora(self, lora_path: Optional[str] = None): + def reformat_output( + self, output: Union[str, List[str]], sampling_parameters: SamplingParameters + ): """ - Update and apply the LoRA to the model. + The purpose of this function is to reformat the output from exllamav2's output format to outline's output format + For exllamav2, it mainly accepts only a list or a string(they also do cfg sampling with tuples but we will ignore this for now) + The exllamav2's logic is + 1. If the prompt is a string, return a string. This is the same as outlines + 2. If a prompt is a list, return a list. This is not the same as outlines output in that if the list is only one element, the string is expected to be outputted. + 3. There is no such thing as num_samples, so the prompts had to be duplicated by num_samples times. Then, we had the function output a list of lists + """ + if isinstance(output, str): + return output + if len(output) == 1: + return output[0] + if sampling_parameters.num_samples > 1: + if len(output) == sampling_parameters.num_samples: + return output + assert len(output) % sampling_parameters.num_samples == 0 + num_items_per_sample = len(output) // sampling_parameters.num_samples + new_output = [] + for i in range(sampling_parameters.num_samples): + curr_sample = [] + for j in range(num_items_per_sample): + curr_sample.append(output[i * num_items_per_sample + j]) + new_output.append(curr_sample) + return new_output + return output - Args: - lora_path (Optional[str]): The path to the LoRA directory. If None, the LoRA will be unloaded. + def generate( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Union[str, List[str]]: + exllamav2_params, prompts = self.prepare_generation_parameters( + prompts, + generation_parameters, + sampling_parameters, + structure_logits_processor, + ) """ - try: - from exllamav2 import ExLlamaV2Lora - except ImportError: - raise ImportError( - "The `exllamav2` library needs to be installed in order to use `exllamav2` models." + In exllamav2, it needs the max amount of new tokens generated. + The reason exllamav2_params["max_new_tokens"] is a list is because in prepare_generation_parameters + the max amount of tokens that can be generated by the model for each prompt(by encoding with tokenizer) is calculated. + The minimum is picked because otherwise it might be possible for one of the + prompts to exceed the max sequence length. + """ + output = self.generator.generate( + prompt=prompts, + gen_settings=exllamav2_params["gen_settings"], + max_new_tokens=min(exllamav2_params["max_new_tokens"]), + completion_only=True, + encode_special_tokens=True, + stop_conditions=exllamav2_params["stop_conditions"], + add_bos=False, + seed=exllamav2_params["seed"], + ) + + return self.reformat_output(output, sampling_parameters) + + def stream( + self, + prompts: Union[str, List[str]], + generation_parameters: GenerationParameters, + structure_logits_processor, + sampling_parameters: SamplingParameters, + **exllamav2_params: Unpack[ExllamaV2Params], + ) -> Iterator[Union[str, List[str]]]: + from exllamav2.generator import ExLlamaV2DynamicJob + + exllamav2_params, prompts = self.prepare_generation_parameters( + prompts, + generation_parameters, + sampling_parameters, + structure_logits_processor, + ) + + order = {} + if isinstance(prompts, str): + prompts = [prompts] + batch_size = len(prompts) + seed = exllamav2_params["seed"] + for idx, p in enumerate(prompts): + input_ids = self.generator.tokenizer.encode( + p, encode_special_tokens=True, add_bos=False ) - if lora_path is None: - if self.lora is not None: - print(" -- Unloading LoRA...") - self.lora = None - else: - self.lora = ExLlamaV2Lora.from_directory(self.model, lora_path) - print(" -- Loading LoRA...") + + job = ExLlamaV2DynamicJob( + input_ids=input_ids, + max_new_tokens=exllamav2_params["max_new_tokens"][idx], + min_new_tokens=0, + seed=seed, + stop_conditions=exllamav2_params["stop_conditions"], + gen_settings=exllamav2_params["gen_settings"], + token_healing=False, + decode_special_tokens=False, + ) + + if seed is not None: + seed += 1 + + serial = self.generator.enqueue(job) + order[serial] = idx + + # Collect outputs until all jobs finish + + next_text = [""] * batch_size + + def token_generator() -> Iterator[str]: + while self.generator.num_remaining_jobs(): + results = self.generator.iterate() + for r in results: + idx = order[r["serial"]] + if r["stage"] == "streaming": + text = r.get("text", "") + next_text[idx] = text + if r["eos"]: + next_text[idx] = "" + yield self.reformat_output(next_text, sampling_parameters) + return + + return token_generator() + + +# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b +def patch_tokenizer(tokenizer): + tokenizer.vocabulary = tokenizer.piece_to_id + tokenizer.special_tokens = set(tokenizer.extended_piece_to_id) + tokenizer.convert_token_to_string = lambda t: t + return tokenizer def exl2( model_path: str, - device: str, + draft_model_path: Optional[str] = None, max_seq_len: Optional[int] = None, - scale_pos_emb: Optional[float] = None, - scale_alpha_value: Optional[float] = None, - no_flash_attn: Optional[bool] = None, - num_experts_per_token: Optional[int] = None, - cache_8bit: bool = False, cache_q4: bool = False, - tokenizer_kwargs: dict = {}, - gpu_split: Optional[str] = None, - low_mem: Optional[bool] = None, - verbose: Optional[bool] = None, + paged: bool = True, + max_chunk_size: Optional[int] = None, ) -> ExLlamaV2Model: """ Load an ExLlamaV2 model. @@ -136,8 +255,6 @@ def exl2( Disable flash attention. Defaults to None. num_experts_per_token (Optional[int], optional) Number of experts per token. Defaults to None. - cache_8bit (bool, optional) - Use 8-bit cache. Defaults to False. cache_q4 (bool, optional) Use Q4 cache. Defaults to False. tokenizer_kwargs (dict, optional) @@ -162,71 +279,62 @@ def exl2( from exllamav2 import ( ExLlamaV2, ExLlamaV2Cache, - ExLlamaV2Cache_8bit, ExLlamaV2Cache_Q4, ExLlamaV2Config, + ExLlamaV2Tokenizer, ) - from transformers import AutoTokenizer + from exllamav2.generator import ExLlamaV2DynamicGenerator + except ImportError: raise ImportError( "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models." ) + config = ExLlamaV2Config(model_path) + if max_chunk_size is not None: + config.max_input_len = max_chunk_size + config.max_attention_size = max_chunk_size**2 - # Load tokenizer - if not verbose: - print(" -- Loading tokenizer...") - tokenizer_kwargs.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs) - # tokenizer = TransformerTokenizer(model_path, **tokenizer_kwargs) - - # Check fasttensors for config - if os.name != "nt": - use_fasttensors = True - else: - use_fasttensors = False - - # Create config - config = ExLlamaV2Config() - config.model_dir = model_path - config.fasttensors = use_fasttensors - config.prepare() - - # Set config options - if max_seq_len is not None: - config.max_seq_len = max_seq_len - if scale_pos_emb is not None: - config.scale_pos_emb = scale_pos_emb - if scale_alpha_value is not None: - config.scale_alpha_value = scale_alpha_value - if no_flash_attn is not None: - config.no_flash_attn = no_flash_attn - if num_experts_per_token is not None: - config.num_experts_per_token = num_experts_per_token - if low_mem: - config.set_low_mem() - - # Prepare the model from the config + config.arch_compat_overrides() model = ExLlamaV2(config) - - # Create cache - if cache_8bit: - cache = ExLlamaV2Cache_8bit(model, lazy=not model.loaded) - elif cache_q4: - cache = ExLlamaV2Cache_Q4(model, lazy=not model.loaded) + if max_seq_len is None: + max_seq_len = -1 + if cache_q4: + cache = ExLlamaV2Cache_Q4(model, max_seq_len=max_seq_len, lazy=True) else: - cache = ExLlamaV2Cache(model, lazy=not model.loaded) - - # Load the model - split = None - if gpu_split and gpu_split != "auto": - split = [float(alloc) for alloc in gpu_split.split(",")] - if not verbose: - print(" -- Loading model...") - model.load(split) - - # Autoload if no GPU split was provided - if not model.loaded: - print(" -- Loading model...") - model.load_autosplit(cache) - - return ExLlamaV2Model(model, tokenizer, device, cache) + cache = ExLlamaV2Cache(model, max_seq_len=max_seq_len, lazy=True) + model.load_autosplit(cache, progress=True) + + print("Loading tokenizer...") + tokenizer = ExLlamaV2Tokenizer(config) + tokenizer = patch_tokenizer(tokenizer) + max_batch_size = 4 if paged else 1 + + draft_model = None + draft_cache = None + if draft_model_path is not None: + draft_config = ExLlamaV2Config(draft_model_path) + draft_model = ExLlamaV2(draft_config) + + if cache_q4: + draft_cache = ExLlamaV2Cache_Q4( + draft_model, max_seq_len=max_seq_len, lazy=True + ) + else: + draft_cache = ExLlamaV2Cache( + draft_model, max_seq_len=max_seq_len, lazy=True + ) + + # Initialize the generator with all default parameters + generator = ExLlamaV2DynamicGenerator( + model=model, + cache=cache, + draft_model=draft_model, + draft_cache=draft_cache, + tokenizer=tokenizer, + max_batch_size=max_batch_size, + use_ngram_draft=False, + max_chunk_size=max_chunk_size, + paged=paged, + ) + max_seq_len = cache.max_seq_len + return ExLlamaV2Model(generator, tokenizer, max_seq_len) diff --git a/pyproject.toml b/pyproject.toml index ab3ecd775..ac94ecf57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ test = [ "torch", "transformers", "pillow", + "exllamav2", ] serve = [ "vllm>=0.3.0", diff --git a/tests/generate/conftest.py b/tests/generate/conftest.py index ed8830119..abd9c72a4 100644 --- a/tests/generate/conftest.py +++ b/tests/generate/conftest.py @@ -27,9 +27,11 @@ def pytest_collection_modifyitems(config, items): for item in items: if "model_fixture" in item.fixturenames: model_param = item.callspec.params.get("model_fixture", None) - if model_param.startswith( - "model_transformers_vision" - ) or model_param.startswith("model_vllm"): + if ( + model_param.startswith("model_transformers_vision") + or model_param.startswith("model_vllm") + or model_param.startswith("model_exllamav2") + ): item.add_marker(skip_marker) if not is_metal_available(): diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index a96ce8673..b36baf9a4 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -20,6 +20,15 @@ def model_llamacpp(tmp_path_factory): ) +@pytest.fixture(scope="session") +def model_exllamav2(tmp_path_factory): + return models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + @pytest.fixture(scope="session") def model_mlxlm(tmp_path_factory): return models.mlxlm("mlx-community/TinyLlama-1.1B-Chat-v1.0-4bit") @@ -98,6 +107,7 @@ def model_t5(tmp_path_factory): ALL_MODEL_FIXTURES = ( "model_llamacpp", + "model_exllamav2", "model_mlxlm", "model_mlxlm_phi3", "model_transformers_random", diff --git a/tests/generate/test_integration_exllamav2.py b/tests/generate/test_integration_exllamav2.py new file mode 100644 index 000000000..12c4143b3 --- /dev/null +++ b/tests/generate/test_integration_exllamav2.py @@ -0,0 +1,363 @@ +import importlib +from unittest.mock import patch + +import pytest + +import outlines.models as models +from outlines.generate.api import GenerationParameters, SamplingParameters +from outlines.models.exllamav2 import ExLlamaV2Model + + +@pytest.fixture(scope="session") +def model_exllamav2(tmp_path_factory): + return models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_import_error(request, model_fixture): + with patch.dict("sys.modules", {"exllamav2": None}): + with pytest.raises(ImportError): + models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_attributes(request, model_fixture): + model = request.getfixturevalue(model_fixture) + assert hasattr(model, "generator") + assert hasattr(model, "tokenizer") + assert model.tokenizer.convert_token_to_string(1) == 1 + assert hasattr(model, "max_seq_len") + assert isinstance(model.max_seq_len, int) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_prompt_types(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + prompt = ["test"] + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_no_max_tokens(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=None, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_test_stop_at(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + generation_params = GenerationParameters(max_tokens=10, stop_at=["stop"], seed=None) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_generate_multisampling(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + output = model.generate( + prompt, generation_params, structure_logits_processor, sampling_params + ) + assert isinstance(output, list) + assert isinstance(output[0], str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_prepare_generation_parameters(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + exllamav2_params, prompts = model.prepare_generation_parameters( + prompt, generation_params, sampling_params, structure_logits_processor + ) + assert isinstance(exllamav2_params, dict) + assert isinstance(prompts, list) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_prompt_types(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + prompt = ["test"] + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_no_max_tokens(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=None, stop_at=None, seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_test_stop_at(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + generation_params = GenerationParameters(max_tokens=10, stop_at=["stop"], seed=None) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_multisampling(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, stop_at="stop", seed=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, list) + assert isinstance(token[0], str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_model_stream_seed(request, model_fixture): + model = request.getfixturevalue(model_fixture) + prompt = "test" + generation_params = GenerationParameters(max_tokens=10, seed=1, stop_at=None) + structure_logits_processor = None + sampling_params = SamplingParameters( + "multinomial", + 1, + 0.9, + 50, + 1.0, + ) + generator = model.stream( + prompt, generation_params, structure_logits_processor, sampling_params + ) + for token in generator: + assert isinstance(token, str) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_reformat_output(request, model_fixture): + model = request.getfixturevalue(model_fixture) + sampling_params = SamplingParameters( + "multinomial", + 1, + ) + output = "test" + reformatted_output = model.reformat_output(output, sampling_params) + assert reformatted_output == output + output = ["test"] + reformatted_output = model.reformat_output(output, sampling_params) + assert reformatted_output == output[0] + output = ["test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 1, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == "test" + assert reformatted_output[1] == "test" + output = ["test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == "test" + assert reformatted_output[1] == "test" + output = ["test", "test", "test", "test"] + sampling_params = SamplingParameters( + "multinomial", + 2, + ) + reformatted_output = model.reformat_output(output, sampling_params) + assert len(reformatted_output) == 2 + assert reformatted_output[0] == ["test", "test"] + assert reformatted_output[1] == ["test", "test"] + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_max_chunk_size(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + max_chunk_size=128, + ) + assert isinstance(model, ExLlamaV2Model) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_cache_default(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + paged=False, + ) + assert isinstance(model, ExLlamaV2Model) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def is_flash_attn_available(): + try: + importlib.import_module("flash_attn") + except (ImportError, AssertionError): + return False + return True + + +@pytest.mark.skipif(not is_flash_attn_available(), reason="flash-attn is not installed") +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_paged(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=True, + ) + assert isinstance(model, ExLlamaV2Model) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_draft_model(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + draft_model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + cache_q4=True, + paged=False, + ) + assert isinstance(model, ExLlamaV2Model) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_draft_model_cache_default(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + draft_model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + paged=False, + ) + assert isinstance(model, ExLlamaV2Model) + + +@pytest.mark.parametrize("model_fixture", ["model_exllamav2"]) +def test_exl2_set_max_seq_len(request, model_fixture): + model = models.exl2( + model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + max_seq_len=2048, + paged=False, + cache_q4=True, + ) + assert isinstance(model, ExLlamaV2Model) From 6a7eb904f11fb1db7e1b2edad2cb08ae77fb60d3 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 4 Oct 2024 16:02:47 -0400 Subject: [PATCH 2/3] automatically download exl2 model in tests fix exl bug: sometimes piece_to_id not populated, but get_piece_to_id() still works fix exl bug: sometimes piece_to_id not populated, but get_piece_to_id() still works enable exl2 in generate.cfg cleate OutlinesExLlamaV2Tokenizer rather than monkey patching --- outlines/generate/cfg.py | 9 +------- outlines/models/exllamav2.py | 39 +++++++++++++++++++++++---------- tests/generate/test_generate.py | 10 ++++++++- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/outlines/generate/cfg.py b/outlines/generate/cfg.py index 034a65ae5..b677040d5 100644 --- a/outlines/generate/cfg.py +++ b/outlines/generate/cfg.py @@ -4,7 +4,7 @@ SequenceGeneratorAdapter, VisionSequenceGeneratorAdapter, ) -from outlines.models import ExLlamaV2Model, LlamaCpp, OpenAI, TransformersVision +from outlines.models import LlamaCpp, OpenAI, TransformersVision from outlines.samplers import Sampler, multinomial @@ -41,13 +41,6 @@ def cfg_vision(model, cfg_str: str, sampler: Sampler = multinomial()): return VisionSequenceGeneratorAdapter(model, logits_processor, sampler) -@cfg.register(ExLlamaV2Model) -def cfg_exllamav2(model, cfg_str: str, sampler: Sampler = multinomial()): - raise NotImplementedError( - "Not yet available, track progress in https://github.com/dottxt-ai/outlines/pull/1010" - ) - - @cfg.register(LlamaCpp) def cfg_llamacpp(model, cfg_str: str, sampler: Sampler = multinomial()): raise NotImplementedError("Not yet available due to bug in llama_cpp tokenizer") diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py index f06b7e46e..821d4e591 100644 --- a/outlines/models/exllamav2.py +++ b/outlines/models/exllamav2.py @@ -1,12 +1,13 @@ import dataclasses from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, TypedDict, Union +import torch from typing_extensions import Unpack from outlines.generate.api import GenerationParameters, SamplingParameters if TYPE_CHECKING: - from exllamav2 import ExLlamaV2Tokenizer + import torch.LongTensor from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2Sampler @@ -18,13 +19,33 @@ class ExllamaV2Params(TypedDict, total=False): max_new_tokens: List[int] +class OutlinesExLlamaV2Tokenizer: + def __init__(self, tokenizer): + self.exl2_tokenizer = tokenizer + self.vocabulary = self.exl2_tokenizer.get_piece_to_id_dict() + self.special_tokens = set(self.exl2_tokenizer.extended_piece_to_id) + self.eos_token_id = self.exl2_tokenizer.eos_token_id + + def convert_token_to_string(self, token): + return token + + def decode(self, token_ids: "torch.LongTensor") -> List[str]: + decoded = self.exl2_tokenizer.decode( + torch.tensor(token_ids), + decode_special_tokens=False, + ) + if isinstance(decoded, str): + return [decoded] + return decoded + + class ExLlamaV2Model: """Represents a `exl2` model.""" def __init__( self, generator: "ExLlamaV2DynamicGenerator", - tokenizer: "ExLlamaV2Tokenizer", + tokenizer: "OutlinesExLlamaV2Tokenizer", max_seq_len: int, ): self.generator = generator @@ -220,14 +241,6 @@ def token_generator() -> Iterator[str]: return token_generator() -# Taken from https://github.com/lapp0/exllamav2/pull/1/files#diff-26f303de07c10aad998e33d3df52581643673a598162cc4b35ef051f52d7c60b -def patch_tokenizer(tokenizer): - tokenizer.vocabulary = tokenizer.piece_to_id - tokenizer.special_tokens = set(tokenizer.extended_piece_to_id) - tokenizer.convert_token_to_string = lambda t: t - return tokenizer - - def exl2( model_path: str, draft_model_path: Optional[str] = None, @@ -306,7 +319,6 @@ def exl2( print("Loading tokenizer...") tokenizer = ExLlamaV2Tokenizer(config) - tokenizer = patch_tokenizer(tokenizer) max_batch_size = 4 if paged else 1 draft_model = None @@ -337,4 +349,7 @@ def exl2( paged=paged, ) max_seq_len = cache.max_seq_len - return ExLlamaV2Model(generator, tokenizer, max_seq_len) + + outlines_tokenizer = OutlinesExLlamaV2Tokenizer(tokenizer) + outlines_exl2_model = ExLlamaV2Model(generator, outlines_tokenizer, max_seq_len) + return outlines_exl2_model diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index b36baf9a4..9c288c21e 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -22,8 +22,16 @@ def model_llamacpp(tmp_path_factory): @pytest.fixture(scope="session") def model_exllamav2(tmp_path_factory): + from huggingface_hub import snapshot_download + + tmp_dir = tmp_path_factory.mktemp("model_download") + model_path = snapshot_download( + repo_id="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4.6-exl2", + cache_dir=tmp_dir, + ) + return models.exl2( - model_path="blockblockblock/TinyLlama-1.1B-Chat-v1.0-bpw4-exl2", + model_path=model_path, cache_q4=True, paged=False, ) From a59c26f2b3a72ad34758f9acbfa717595b623af1 Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Fri, 4 Oct 2024 22:24:29 -0400 Subject: [PATCH 3/3] document third party exllamav2 with logits processor --- docs/reference/models/exllamav2.md | 10 +++++++++- outlines/models/exllamav2.py | 4 +++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/reference/models/exllamav2.md b/docs/reference/models/exllamav2.md index afe542112..a4e727840 100644 --- a/docs/reference/models/exllamav2.md +++ b/docs/reference/models/exllamav2.md @@ -1,7 +1,15 @@ # ExllamaV2 +The `outlines.models.exllamav2` model requires a Logits Processor component for compatibility with Outlines structured generation. While ExLlamaV2 doesn't natively support this feature, a third-party fork provides the necessary functionality. You can install it with the following command: + +```bash +pip install git+https://github.com/lapp0/exllamav2@sampler-logits-processor +``` + +Install other requirements: + ```bash -pip install exllamav2 transformers torch +pip install transformers torch ``` *Coming soon* diff --git a/outlines/models/exllamav2.py b/outlines/models/exllamav2.py index 821d4e591..78da796fb 100644 --- a/outlines/models/exllamav2.py +++ b/outlines/models/exllamav2.py @@ -300,7 +300,9 @@ def exl2( except ImportError: raise ImportError( - "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models." + "The `exllamav2`, `transformers` and `torch` libraries needs to be installed in order to use `exllamav2` models. " + "Please run `pip install transformers torch git+https://github.com/lapp0/exllamav2@sampler-logits-processor` " + "Documentation: https://dottxt-ai.github.io/outlines/reference/models/exllamav2/" ) config = ExLlamaV2Config(model_path) if max_chunk_size is not None: