
Commit 0df05f0

Add Custom Generation function support

1 parent b0af8d1 commit 0df05f0

File tree

- llmx/configs/config.default.yml
- llmx/generators/text/custom_textgen.py
- llmx/generators/text/textgen.py
- tests/test_generators.py

4 files changed: +80 −2 lines changed

llmx/configs/config.default.yml

Lines changed: 3 additions & 0 deletions

@@ -147,3 +147,6 @@ providers:
       model: uukuguy/speechless-llama2-hermes-orca-platypus-13b
       device_map: auto
       trust_remote_code: true
+  custom:
+    name: Custom
+    description: Custom Text Generation
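
The new providers entry only registers display metadata (a name and description) for the custom provider; no model list is needed, since the generation logic is supplied at runtime. A quick way to confirm the entry parses as intended is to load the config directly. A minimal sketch, assuming PyYAML is available and the repository-relative path below:

import yaml

with open("llmx/configs/config.default.yml") as f:
    cfg = yaml.safe_load(f)

# Per the diff above, this should print:
# {'name': 'Custom', 'description': 'Custom Text Generation'}
print(cfg["providers"]["custom"])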
llmx/generators/text/custom_textgen.py

Lines changed: 58 additions & 0 deletions

@@ -0,0 +1,58 @@
+from typing import Union, List, Dict, Callable
+from dataclasses import asdict
+from .base_textgen import TextGenerator
+from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
+from ...utils import cache_request, num_tokens_from_messages
+
+
+class CustomTextGenerator(TextGenerator):
+    def __init__(
+        self,
+        text_generation_function: Callable[[str], str],
+        provider: str = "custom",
+        **kwargs
+    ):
+        super().__init__(provider=provider, **kwargs)
+        self.text_generation_function = text_generation_function
+
+    def generate(
+        self,
+        messages: Union[List[Dict], str],
+        config: TextGenerationConfig = TextGenerationConfig(),
+        **kwargs
+    ) -> TextGenerationResponse:
+        use_cache = config.use_cache
+        messages = self.format_messages(messages)
+        cache_key = {"messages": messages, "config": asdict(config)}
+        if use_cache:
+            response = cache_request(cache=self.cache, params=cache_key)
+            if response:
+                return TextGenerationResponse(**response)
+
+        generation_response = self.text_generation_function(messages)
+        response = TextGenerationResponse(
+            text=[Message(role="system", content=generation_response)],
+            logprobs=[],  # log probabilities could be extracted here if the custom function provides them
+            usage={},
+            config={},
+        )
+
+        if use_cache:
+            cache_request(
+                cache=self.cache, params=cache_key, values=asdict(response)
+            )
+
+        return response
+
+    def format_messages(self, messages) -> str:
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                prompt += message["content"] + "\n"
+            else:
+                prompt += message["role"] + ": " + message["content"] + "\n"
+
+        return prompt
+
+    def count_tokens(self, text) -> int:
+        return num_tokens_from_messages(text)
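
CustomTextGenerator delegates all generation to the user-supplied text_generation_function, which receives the single prompt string produced by format_messages and returns the generated text; caching and token counting reuse the shared llmx utilities. A minimal sketch of using the class directly (echo_fn is a hypothetical stand-in for any Callable[[str], str]):

from llmx.generators.text.custom_textgen import CustomTextGenerator

def echo_fn(prompt: str) -> str:
    # Hypothetical generation function: receives the flattened prompt
    # ("role: content" lines) and returns the generated text.
    return "You said: " + prompt.strip().splitlines()[-1]

gen = CustomTextGenerator(text_generation_function=echo_fn)
response = gen.generate([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
])
print(response.text[0].content)

Note that the response always contains exactly one Message, because the custom function returns a single string.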

llmx/generators/text/textgen.py

Lines changed: 7 additions & 2 deletions

@@ -3,6 +3,7 @@
 from .palm_textgen import PalmTextGenerator
 from .cohere_textgen import CohereTextGenerator
 from .anthropic_textgen import AnthropicTextGenerator
+from .custom_textgen import CustomTextGenerator
 import logging
 
 logger = logging.getLogger("llmx")

@@ -19,9 +20,11 @@ def sanitize_provider(provider: str):
         return "hf"
     elif provider.lower() == "anthropic" or provider.lower() == "claude":
         return "anthropic"
+    elif provider.lower() == "custom":
+        return "custom"
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
         )

@@ -58,6 +61,8 @@ def llm(provider: str = None, **kwargs):
         return CohereTextGenerator(**kwargs)
     elif provider.lower() == "anthropic":
         return AnthropicTextGenerator(**kwargs)
+    elif provider.lower() == "custom":
+        return CustomTextGenerator(**kwargs)
     elif provider.lower() == "hf":
         try:
             import transformers

@@ -80,5 +85,5 @@ def llm(provider: str = None, **kwargs):
 
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
         )
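
With both dispatch points patched, the factory route works the same way as for the built-in providers. A minimal sketch, assuming llm is importable from the package root as in the tests below:

from llmx import llm

custom_gen = llm(
    provider="custom",
    # Illustrative stub; any Callable[[str], str] can go here.
    text_generation_function=lambda prompt: "Paris is the capital of France.",
)
response = custom_gen.generate([
    {"role": "user", "content": "What is the capital of France?"},
])
print(response.text[0].content)  # "Paris is the capital of France."

Note that sanitize_provider and llm maintain parallel branch lists, so adding a provider means touching both, which is why the commit also updates both error messages.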

tests/test_generators.py

Lines changed: 12 additions & 0 deletions

@@ -74,3 +74,15 @@ def test_hf_local():
 
     assert ("paris" in answer.lower())
     assert len(hf_local_response.text) == 2
+
+def test_custom():
+    custom_gen = llm(
+        provider="custom",
+        text_generation_function=lambda text: "paris",
+    )
+
+    custom_response = custom_gen.generate(messages, config=config)
+    answer = custom_response.text[0].content
+
+    assert ("paris" in answer.lower())
+    assert len(custom_response.text) == 1
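
The test exercises the full llm factory path and checks both the content and the shape of the response: len(custom_response.text) == 1 holds because CustomTextGenerator always wraps the returned string in a single Message, unlike the hf test above, which expects two candidates. The messages and config names are module-level fixtures defined earlier in tests/test_generators.py; a self-contained equivalent might look like the sketch below, where the fixture values are assumptions, not the repository's actual ones:

from llmx import llm
from llmx.datamodel import TextGenerationConfig

# Assumed stand-ins for the module-level fixtures used by the real test.
messages = [{"role": "user", "content": "What is the capital of France?"}]
config = TextGenerationConfig(use_cache=False)

def test_custom():
    custom_gen = llm(provider="custom", text_generation_function=lambda text: "paris")
    custom_response = custom_gen.generate(messages, config=config)
    assert "paris" in custom_response.text[0].content.lower()
    assert len(custom_response.text) == 1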
