Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
You can contribute back the changes you want to make.
"""

import re
import dataclasses
from enum import auto, IntEnum
from typing import List, Any, Dict
Expand All @@ -28,6 +29,7 @@ class SeparatorStyle(IntEnum):
PHOENIX = auto()
ROBIN = auto()
JSLM_ALPHA = auto()
RINNA = auto()


@dataclasses.dataclass
Expand Down Expand Up @@ -208,6 +210,15 @@ def get_prompt(self) -> str:
else:
ret += role + ": \n"
return ret
elif self.sep_style == SeparatorStyle.RINNA:
ret = ""
for role, message in self.messages:
if message:
message = re.sub(r'\n+', '<NL>', message)
ret += role + ": " + message + self.sep
else:
ret += role + ": "
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")

Expand Down Expand Up @@ -360,6 +371,35 @@ def get_conv_template(name: str) -> Conversation:
)
)

# Conversation template for matsuo-lab/weblab-10b-instruction-sft.
# Alpaca-style Japanese instruction format: each turn is introduced by
# "\n\n### " followed by the role name ("指示" = instruction, "応答" = response),
# and generation stops on the model's <|endoftext|> token.
_weblab_conv = Conversation(
    name="weblab",
    system_message=(
        "以下は、タスクを説明する指示です。"
        "要求を適切に満たす応答を書きなさい。"
    ),
    roles=("指示", "応答"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.JSLM_ALPHA,
    sep="\n\n### ",
    stop_str="<|endoftext|>",
    add_special_tokens=False,
)
register_conv_template(_weblab_conv)

# Conversation template for rinna's japanese-gpt-neox instruction models.
# SeparatorStyle.RINNA collapses newline runs inside each message into the
# literal "<NL>" token and joins turns with the "<NL>" separator.
# NOTE(review): no stop_str is set here — presumably generation stops on the
# model's EOS token; confirm against the serving code.
_rinna_conv = Conversation(
    name="rinna",
    system_message="",
    roles=("ユーザー", "システム"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.RINNA,
    sep="<NL>",
    add_special_tokens=False,
)
register_conv_template(_rinna_conv)

# Vicuna v1.1 template
register_conv_template(
Conversation(
Expand Down
48 changes: 47 additions & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,50 @@ def get_default_conv_template(self, model_path:str):
# TODO: (meng) might need to adapt default conv tpl based on model version
return get_conv_template("jslm_alpha")


class WeblabAdapter(BaseModelAdapter):
    """Model adapter for matsuo-lab/weblab-10b-instruction-sft."""

    def match(self, model_path: str):
        # Match only the instruction-tuned checkpoint by its path suffix.
        return model_path.endswith("weblab-10b-instruction-sft")

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        # Tokenizers do not understand model-loading kwargs such as
        # torch_dtype / device_map, so forward only the revision to the
        # tokenizer and keep the full kwargs for the model load.
        revision = from_pretrained_kwargs.get("revision", "main")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            revision=revision,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **from_pretrained_kwargs,
        )
        return model, tokenizer

    def get_default_conv_template(self, model_path: str):
        # Pair the model with the "weblab" conversation template.
        return get_conv_template("weblab")


class RinnaAdapter(BaseModelAdapter):
    """Model adapter for rinna/japanese-gpt-neox-3.6b-instruction-ppo."""

    def match(self, model_path: str):
        # Match the PPO instruction checkpoint by its path suffix.
        return model_path.endswith("japanese-gpt-neox-3.6b-instruction-ppo")

    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
        # Tokenizers do not understand model-loading kwargs such as
        # torch_dtype / device_map, so forward only the revision.
        # use_fast=False is required: rinna ships a sentencepiece
        # (slow) tokenizer.
        revision = from_pretrained_kwargs.get("revision", "main")
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=False,
            revision=revision,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            **from_pretrained_kwargs,
        )
        return model, tokenizer

    def get_default_conv_template(self, model_path: str):
        # Pair the model with the "rinna" conversation template.
        return get_conv_template("rinna")


class VicunaAdapter(BaseModelAdapter):
"Model adapater for Vicuna models (e.g., lmsys/vicuna-7b-v1.3)" ""

Expand Down Expand Up @@ -1430,6 +1474,8 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
# Note: the registration order matters.
# The one registered earlier has a higher matching priority.
register_model_adapter(JSLMAlphaAdapter)
# WeblabAdapter and RinnaAdapter match on narrow endswith() suffixes, so
# they are registered before the broader adapters below to win the
# first-match priority check.
register_model_adapter(WeblabAdapter)
register_model_adapter(RinnaAdapter)
register_model_adapter(PeftModelAdapter)
register_model_adapter(VicunaAdapter)
register_model_adapter(AiroborosAdapter)
Expand Down Expand Up @@ -1527,4 +1573,4 @@ def build_prompt(user_query, inputs="", sep="\n\n### "):
)

out = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(out)
print(out)