diff --git a/fastchat/conversation.py b/fastchat/conversation.py index 0667aaadd..2bbd463b7 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -5,6 +5,7 @@ You can contribute back the changes you want to make. """ +import re import dataclasses from enum import auto, IntEnum from typing import List, Any, Dict @@ -28,6 +29,7 @@ class SeparatorStyle(IntEnum): PHOENIX = auto() ROBIN = auto() JSLM_ALPHA = auto() + RINNA = auto() @dataclasses.dataclass @@ -208,6 +210,15 @@ def get_prompt(self) -> str: else: ret += role + ": \n" return ret + elif self.sep_style == SeparatorStyle.RINNA: + ret = "" + for role, message in self.messages: + if message: + message = re.sub(r'\n+', '', message) + ret += role + ": " + message + self.sep + else: + ret += role + ": " + return ret else: raise ValueError(f"Invalid style: {self.sep_style}") @@ -360,6 +371,35 @@ def get_conv_template(name: str) -> Conversation: ) ) +# conv template for matsuo-lab/weblab-10b-instruction-sft +register_conv_template( + Conversation( + name="weblab", + system_message="以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。", + roles=("指示", "応答"), + messages=(), + offset=0, + sep_style=SeparatorStyle.JSLM_ALPHA, + sep="\n\n### ", + stop_str="<|endoftext|>", + add_special_tokens=False, + ) ) + +# conv template for rinna +register_conv_template( + Conversation( + name="rinna", + system_message="", + roles=("ユーザー", "システム"), + messages=(), + offset=0, + sep_style=SeparatorStyle.RINNA, + sep="", + add_special_tokens=False, + ) ) + +# Vicuna v1.1 template register_conv_template( Conversation( diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index b24fe23a7..8c51c72f3 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -547,6 +547,50 @@ def get_default_conv_template(self, model_path:str): # TODO: (meng) might need to adapt default conv tpl based on model version return get_conv_template("jslm_alpha") + +class WeblabAdapter(BaseModelAdapter): + "Model 
adapter for weblab-10b-instruction-sft" + + def match(self, model_path: str): + return model_path.endswith("weblab-10b-instruction-sft") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str): + return get_conv_template("weblab") + + +class RinnaAdapter(BaseModelAdapter): + "Model adapter for rinna" + + def match(self, model_path: str): + return model_path.endswith("japanese-gpt-neox-3.6b-instruction-ppo") + + def load_model(self, model_path: str, from_pretrained_kwargs: dict): + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=False, + **from_pretrained_kwargs, + ) + model = AutoModelForCausalLM.from_pretrained( + model_path, + **from_pretrained_kwargs, + ) + return model, tokenizer + + def get_default_conv_template(self, model_path: str): + return get_conv_template("rinna") + + class VicunaAdapter(BaseModelAdapter): "Model adapater for Vicuna models (e.g., lmsys/vicuna-7b-v1.3)" "" @@ -1430,6 +1474,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation: # Note: the registration order matters. # The one registered earlier has a higher matching priority. register_model_adapter(JSLMAlphaAdapter) +register_model_adapter(WeblabAdapter) +register_model_adapter(RinnaAdapter) register_model_adapter(PeftModelAdapter) register_model_adapter(VicunaAdapter) register_model_adapter(AiroborosAdapter) @@ -1527,4 +1573,4 @@ def build_prompt(user_query, inputs="", sep="\n\n### "): ) out = tokenizer.decode(tokens[0], skip_special_tokens=True) - print(out) \ No newline at end of file + print(out)