Skip to content

Commit ee2d46d

Browse files
committed
Add tests
1 parent 3f41578 commit ee2d46d

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

tests/test_cli.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import sys
22
import types
33
from io import StringIO
4+
from pathlib import Path
45
from typing import Any, Callable
56

67
import pytest
@@ -22,7 +23,7 @@
2223
from prompt_toolkit.output import DummyOutput
2324
from prompt_toolkit.shortcuts import PromptSession
2425

25-
from pydantic_ai._cli import cli, cli_agent, handle_slash_command
26+
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, PYDANTIC_AI_HOME, cli, cli_agent, handle_slash_command
2627
from pydantic_ai.models.openai import OpenAIModel
2728

2829
pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
@@ -56,6 +57,16 @@ def _create_test_module(**namespace: Any) -> None:
5657
del sys.modules['test_module']
5758

5859

60+
@pytest.fixture
61+
def emtpy_last_conversation_path():
62+
path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME
63+
64+
if path.exists():
65+
path.unlink()
66+
67+
return path
68+
69+
5970
def test_agent_flag(
6071
capfd: CaptureFixture[str],
6172
mocker: MockerFixture,
@@ -162,6 +173,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
162173
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])
163174

164175

176+
@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
177+
def test_cli_continue_last_conversation(
178+
args: list[str],
179+
capfd: CaptureFixture[str],
180+
env: TestEnv,
181+
emtpy_last_conversation_path: Path,
182+
):
183+
env.set('OPENAI_API_KEY', 'test')
184+
with cli_agent.override(model=TestModel(custom_output_text='# world')):
185+
assert cli(args) == 0
186+
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
187+
assert emtpy_last_conversation_path.exists()
188+
content = emtpy_last_conversation_path.read_text()
189+
assert content
190+
191+
assert cli(args) == 0
192+
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
193+
assert emtpy_last_conversation_path.exists()
194+
# verity that new content is appended to the file
195+
assert len(emtpy_last_conversation_path.read_text()) > len(content)
196+
197+
198+
@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
199+
def test_cli_continue_last_conversation_corrupted_file(
200+
args: list[str],
201+
capfd: CaptureFixture[str],
202+
env: TestEnv,
203+
emtpy_last_conversation_path: Path,
204+
):
205+
env.set('OPENAI_API_KEY', 'test')
206+
emtpy_last_conversation_path.write_text('not a valid json')
207+
with cli_agent.override(model=TestModel(custom_output_text='# world')):
208+
assert cli(args) == 0
209+
assert capfd.readouterr().out.splitlines() == snapshot(
210+
[
211+
IsStr(),
212+
'Error loading last conversation, it is corrupted or invalid. Starting a new ',
213+
'conversation.',
214+
'# world',
215+
]
216+
)
217+
assert emtpy_last_conversation_path.exists()
218+
assert emtpy_last_conversation_path.read_text()
219+
220+
165221
def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
166222
env.set('OPENAI_API_KEY', 'test')
167223
with create_pipe_input() as inp:

0 commit comments

Comments
 (0)