Skip to content

Commit 9f644b6

Browse files
Actually test the Python translator (#240)
* Add a test for the built-in translator and add the snapshots. * Ensure the repair prompt is actually appended to our messages and update snapshots. * Add uncommitted file.
1 parent 66fd7bb commit 9f644b6

File tree

4 files changed

+204
-12
lines changed

4 files changed

+204
-12
lines changed

python/examples/healthData/translator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def __init__(
2727
self._additional_agent_instructions = additional_agent_instructions
2828

2929
@override
30-
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
31-
result = await super().translate(request=request, prompt_preamble=prompt_preamble)
30+
async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
31+
result = await super().translate(input=input, prompt_preamble=prompt_preamble)
3232
if not isinstance(result, Failure):
3333
self._chat_history.append(ChatMessage(source="assistant", body=result.value))
3434
return result

python/src/typechat/_internal/translator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,31 +49,33 @@ def __init__(
4949
self._type_name = conversion_result.typescript_type_reference
5050
self._schema_str = conversion_result.typescript_schema_str
5151

52-
async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
52+
async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]:
5353
"""
5454
Translates a natural language request into an object of type `T`. If the JSON object returned by
5555
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
5656
The prompt for the subsequent attempts will include the diagnostics produced for the prior attempt.
5757
This often helps produce a valid instance.
5858
5959
Args:
60-
request: A natural language request.
60+
input: A natural language request.
6161
prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\
6262
If a string is given, it is converted to a single "user" role prompt section.
6363
"""
64-
request = self._create_request_prompt(request)
6564

66-
prompt: str | list[PromptSection]
67-
if prompt_preamble is None:
68-
prompt = request
69-
else:
65+
messages: list[PromptSection] = []
66+
67+
messages.append({"role": "user", "content": input})
68+
if prompt_preamble:
7069
if isinstance(prompt_preamble, str):
7170
prompt_preamble = [{"role": "user", "content": prompt_preamble}]
72-
prompt = [*prompt_preamble, {"role": "user", "content": request}]
71+
else:
72+
messages.extend(prompt_preamble)
73+
74+
messages.append({"role": "user", "content": self._create_request_prompt(input)})
7375

7476
num_repairs_attempted = 0
7577
while True:
76-
completion_response = await self.model.complete(prompt)
78+
completion_response = await self.model.complete(messages)
7779
if isinstance(completion_response, Failure):
7880
return completion_response
7981

@@ -93,7 +95,7 @@ async def translate(self, request: str, *, prompt_preamble: str | list[PromptSec
9395
if num_repairs_attempted >= self._max_repair_attempts:
9496
return Failure(error_message)
9597
num_repairs_attempted += 1
96-
request = f"{text_response}\n{self._create_repair_prompt(error_message)}"
98+
messages.append({"role": "user", "content": self._create_repair_prompt(error_message)})
9799

98100
def _create_request_prompt(self, intent: str) -> str:
99101
prompt = f"""
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# serializer version: 1
2+
# name: test_translator_with_immediate_pass
3+
list([
4+
dict({
5+
'kind': 'CLIENT REQUEST',
6+
'payload': list([
7+
dict({
8+
'content': 'Get me stuff.',
9+
'role': 'user',
10+
}),
11+
dict({
12+
'content': '''
13+
14+
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
15+
```
16+
interface ExampleABC {
17+
a: string;
18+
b: boolean;
19+
c: number;
20+
}
21+
22+
```
23+
The following is a user request:
24+
'''
25+
Get me stuff.
26+
'''
27+
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:
28+
29+
''',
30+
'role': 'user',
31+
}),
32+
]),
33+
}),
34+
dict({
35+
'kind': 'MODEL RESPONSE',
36+
'payload': '{ "a": "hello", "b": true, "c": 1234 }',
37+
}),
38+
])
39+
# ---
40+
# name: test_translator_with_single_failure
41+
list([
42+
dict({
43+
'kind': 'CLIENT REQUEST',
44+
'payload': list([
45+
dict({
46+
'content': 'Get me stuff.',
47+
'role': 'user',
48+
}),
49+
dict({
50+
'content': '''
51+
52+
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
53+
```
54+
interface ExampleABC {
55+
a: string;
56+
b: boolean;
57+
c: number;
58+
}
59+
60+
```
61+
The following is a user request:
62+
'''
63+
Get me stuff.
64+
'''
65+
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:
66+
67+
''',
68+
'role': 'user',
69+
}),
70+
dict({
71+
'content': '''
72+
73+
The above JSON object is invalid for the following reason:
74+
'''
75+
Validation path `c` failed for value `{"a": "hello", "b": true}` because:
76+
Field required
77+
'''
78+
The following is a revised JSON object:
79+
80+
''',
81+
'role': 'user',
82+
}),
83+
]),
84+
}),
85+
dict({
86+
'kind': 'MODEL RESPONSE',
87+
'payload': '{ "a": "hello", "b": true }',
88+
}),
89+
dict({
90+
'kind': 'CLIENT REQUEST',
91+
'payload': list([
92+
dict({
93+
'content': 'Get me stuff.',
94+
'role': 'user',
95+
}),
96+
dict({
97+
'content': '''
98+
99+
You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions:
100+
```
101+
interface ExampleABC {
102+
a: string;
103+
b: boolean;
104+
c: number;
105+
}
106+
107+
```
108+
The following is a user request:
109+
'''
110+
Get me stuff.
111+
'''
112+
The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined:
113+
114+
''',
115+
'role': 'user',
116+
}),
117+
dict({
118+
'content': '''
119+
120+
The above JSON object is invalid for the following reason:
121+
'''
122+
Validation path `c` failed for value `{"a": "hello", "b": true}` because:
123+
Field required
124+
'''
125+
The following is a revised JSON object:
126+
127+
''',
128+
'role': 'user',
129+
}),
130+
]),
131+
}),
132+
dict({
133+
'kind': 'MODEL RESPONSE',
134+
'payload': '{ "a": "hello", "b": true, "c": 1234 }',
135+
}),
136+
])
137+
# ---

python/tests/test_translator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
import asyncio
3+
from dataclasses import dataclass
4+
from typing_extensions import Any, Iterator, Literal, TypedDict, override
5+
import typechat
6+
7+
class ConvoRecord(TypedDict):
8+
kind: Literal["CLIENT REQUEST", "MODEL RESPONSE"]
9+
payload: str | list[typechat.PromptSection]
10+
11+
class FixedModel(typechat.TypeChatLanguageModel):
12+
responses: Iterator[str]
13+
conversation: list[ConvoRecord]
14+
15+
"A model which responds with one of a series of responses."
16+
def __init__(self, responses: list[str]) -> None:
17+
super().__init__()
18+
self.responses = iter(responses)
19+
self.conversation = []
20+
21+
@override
22+
async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]:
23+
self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt })
24+
response = next(self.responses)
25+
self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response })
26+
return typechat.Success(response)
27+
28+
@dataclass
29+
class ExampleABC:
30+
a: str
31+
b: bool
32+
c: int
33+
34+
v = typechat.TypeChatValidator(ExampleABC)
35+
36+
def test_translator_with_immediate_pass(snapshot: Any):
37+
m = FixedModel([
38+
'{ "a": "hello", "b": true, "c": 1234 }',
39+
])
40+
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC)
41+
asyncio.run(t.translate("Get me stuff."))
42+
43+
assert m.conversation == snapshot
44+
45+
def test_translator_with_single_failure(snapshot: Any):
46+
m = FixedModel([
47+
'{ "a": "hello", "b": true }',
48+
'{ "a": "hello", "b": true, "c": 1234 }',
49+
])
50+
t = typechat.TypeChatJsonTranslator(m, v, ExampleABC)
51+
asyncio.run(t.translate("Get me stuff."))
52+
53+
assert m.conversation == snapshot

0 commit comments

Comments
 (0)