|
1 | 1 | import sys
|
2 | 2 | import types
|
3 | 3 | from io import StringIO
|
| 4 | +from pathlib import Path |
4 | 5 | from typing import Any, Callable
|
5 | 6 |
|
6 | 7 | import pytest
|
|
22 | 23 | from prompt_toolkit.output import DummyOutput
|
23 | 24 | from prompt_toolkit.shortcuts import PromptSession
|
24 | 25 |
|
25 |
| - from pydantic_ai._cli import cli, cli_agent, handle_slash_command |
| 26 | + from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, PYDANTIC_AI_HOME, cli, cli_agent, handle_slash_command |
26 | 27 | from pydantic_ai.models.openai import OpenAIModel
|
27 | 28 |
|
28 | 29 | pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
|
@@ -56,6 +57,16 @@ def _create_test_module(**namespace: Any) -> None:
|
56 | 57 | del sys.modules['test_module']
|
57 | 58 |
|
58 | 59 |
|
| 60 | +@pytest.fixture |
| 61 | +def emtpy_last_conversation_path(): |
| 62 | + path = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME |
| 63 | + |
| 64 | + if path.exists(): |
| 65 | + path.unlink() |
| 66 | + |
| 67 | + return path |
| 68 | + |
| 69 | + |
59 | 70 | def test_agent_flag(
|
60 | 71 | capfd: CaptureFixture[str],
|
61 | 72 | mocker: MockerFixture,
|
@@ -162,6 +173,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
|
162 | 173 | assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])
|
163 | 174 |
|
164 | 175 |
|
| 176 | +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) |
| 177 | +def test_cli_continue_last_conversation( |
| 178 | + args: list[str], |
| 179 | + capfd: CaptureFixture[str], |
| 180 | + env: TestEnv, |
| 181 | + emtpy_last_conversation_path: Path, |
| 182 | +): |
| 183 | + env.set('OPENAI_API_KEY', 'test') |
| 184 | + with cli_agent.override(model=TestModel(custom_output_text='# world')): |
| 185 | + assert cli(args) == 0 |
| 186 | + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) |
| 187 | + assert emtpy_last_conversation_path.exists() |
| 188 | + content = emtpy_last_conversation_path.read_text() |
| 189 | + assert content |
| 190 | + |
| 191 | + assert cli(args) == 0 |
| 192 | + assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world']) |
| 193 | + assert emtpy_last_conversation_path.exists() |
| 194 | + # verity that new content is appended to the file |
| 195 | + assert len(emtpy_last_conversation_path.read_text()) > len(content) |
| 196 | + |
| 197 | + |
| 198 | +@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']]) |
| 199 | +def test_cli_continue_last_conversation_corrupted_file( |
| 200 | + args: list[str], |
| 201 | + capfd: CaptureFixture[str], |
| 202 | + env: TestEnv, |
| 203 | + emtpy_last_conversation_path: Path, |
| 204 | +): |
| 205 | + env.set('OPENAI_API_KEY', 'test') |
| 206 | + emtpy_last_conversation_path.write_text('not a valid json') |
| 207 | + with cli_agent.override(model=TestModel(custom_output_text='# world')): |
| 208 | + assert cli(args) == 0 |
| 209 | + assert capfd.readouterr().out.splitlines() == snapshot( |
| 210 | + [ |
| 211 | + IsStr(), |
| 212 | + 'Error loading last conversation, it is corrupted or invalid. Starting a new ', |
| 213 | + 'conversation.', |
| 214 | + '# world', |
| 215 | + ] |
| 216 | + ) |
| 217 | + assert emtpy_last_conversation_path.exists() |
| 218 | + assert emtpy_last_conversation_path.read_text() |
| 219 | + |
| 220 | + |
165 | 221 | def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
|
166 | 222 | env.set('OPENAI_API_KEY', 'test')
|
167 | 223 | with create_pipe_input() as inp:
|
|
0 commit comments