Skip to content

clai: Add ability to continue last conversation #2257

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Either way, running `clai` will start an interactive session where you can chat
## Help

```
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c] [--no-stream] [--version] [prompt]

Pydantic AI CLI v...

Expand All @@ -74,6 +74,7 @@ options:
-l, --list-models List all available models and exit
-t [CODE_THEME], --code-theme [CODE_THEME]
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.
-c, --continue Continue last conversation, if any, instead of starting a new one.
--no-stream Disable streaming from the model
--version Show version and exit
```
54 changes: 47 additions & 7 deletions pydantic_ai_slim/pydantic_ai/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from pathlib import Path
from typing import Any, cast

from pydantic import ValidationError
from typing_inspection.introspection import get_literal_values

from . import __version__
from ._run_context import AgentDepsT
from .agent import Agent
from .exceptions import UserError
from .messages import ModelMessage
from .messages import ModelMessage, ModelMessagesTypeAdapter
from .models import KnownModelName, infer_model
from .output import OutputDataT

Expand Down Expand Up @@ -53,6 +54,7 @@
"""

PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
LAST_CONVERSATION_FILENAME = 'last-conversation.json'


class SimpleCodeBlock(CodeBlock):
Expand Down Expand Up @@ -146,6 +148,13 @@ def cli( # noqa: C901
help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
default='dark',
)
parser.add_argument(
'-c',
'--continue',
dest='continue_',
action='store_true',
help='Continue last conversation, if any, instead of starting a new one.',
)
parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
parser.add_argument('--version', action='store_true', help='Show version and exit')

Expand Down Expand Up @@ -205,19 +214,42 @@ def cli( # noqa: C901
else:
code_theme = args.code_theme # pragma: no cover

try:
history = load_last_conversation() if args.continue_ else None
except ValidationError:
console.print(
'[red]Error loading last conversation, it is corrupted or invalid.\nStarting a new conversation.[/red]'
)
history = None

if prompt := cast(str, args.prompt):
try:
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history))
except KeyboardInterrupt:
pass
return 0

try:
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name, history=history))
except KeyboardInterrupt: # pragma: no cover
return 0


def store_last_conversation(messages: list[ModelMessage], config_dir: Path | None = None) -> None:
    """Persist *messages* as JSON so a later run can resume the conversation.

    Args:
        messages: Full message history of the conversation to store.
        config_dir: Directory to store the file in; defaults to ``PYDANTIC_AI_HOME``.
    """
    target = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME
    # Make sure the config directory exists before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = ModelMessagesTypeAdapter.dump_json(messages)
    target.write_bytes(serialized)


def load_last_conversation(config_dir: Path | None = None) -> list[ModelMessage] | None:
    """Load the previously stored conversation, if any.

    Args:
        config_dir: Directory the conversation file was stored in; defaults to ``PYDANTIC_AI_HOME``.

    Returns:
        The stored message history, or ``None`` if no conversation file exists.

    Raises:
        pydantic.ValidationError: If the stored file is corrupted or is not valid JSON.
    """
    last_conversation_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME

    if not last_conversation_path.exists():
        return None

    # Read raw bytes to mirror `store_last_conversation`'s `write_bytes`: the file is
    # UTF-8 JSON produced by `dump_json`, and `read_text()` would decode it with the
    # locale's default encoding on some platforms, mis-decoding non-ASCII content.
    return ModelMessagesTypeAdapter.validate_json(last_conversation_path.read_bytes())


async def run_chat(
stream: bool,
agent: Agent[AgentDepsT, OutputDataT],
Expand All @@ -226,14 +258,15 @@ async def run_chat(
prog_name: str,
config_dir: Path | None = None,
deps: AgentDepsT = None,
history: list[ModelMessage] | None = None,
) -> int:
prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
prompt_history_path.touch(exist_ok=True)
session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))

multiline = False
messages: list[ModelMessage] = []
messages: list[ModelMessage] = history or []

while True:
try:
Expand All @@ -252,7 +285,7 @@ async def run_chat(
return exit_value
else:
try:
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, config_dir)
except CancelledError: # pragma: no cover
console.print('[dim]Interrupted[/dim]')
except Exception as e: # pragma: no cover
Expand All @@ -270,6 +303,7 @@ async def ask_agent(
code_theme: str,
deps: AgentDepsT = None,
messages: list[ModelMessage] | None = None,
config_dir: Path | None = None,
) -> list[ModelMessage]:
status = Status('[dim]Working on it…[/dim]', console=console)

Expand All @@ -278,7 +312,10 @@ async def ask_agent(
result = await agent.run(prompt, message_history=messages, deps=deps)
content = str(result.output)
console.print(Markdown(content, code_theme=code_theme))
return result.all_messages()
result_messages = result.all_messages()
store_last_conversation(result_messages, config_dir)

return result_messages

with status, ExitStack() as stack:
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
Expand All @@ -293,7 +330,10 @@ async def ask_agent(
live.update(Markdown(str(content), code_theme=code_theme))

assert agent_run.result is not None
return agent_run.result.all_messages()
result_messages = agent_run.result.all_messages()
store_last_conversation(result_messages, config_dir)

return result_messages


class CustomAutoSuggest(AutoSuggestFromHistory):
Expand Down
64 changes: 60 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import types
from io import StringIO
from pathlib import Path
from typing import Any, Callable

import pytest
Expand All @@ -22,7 +23,7 @@
from prompt_toolkit.output import DummyOutput
from prompt_toolkit.shortcuts import PromptSession

from pydantic_ai._cli import cli, cli_agent, handle_slash_command
from pydantic_ai._cli import LAST_CONVERSATION_FILENAME, PYDANTIC_AI_HOME, cli, cli_agent, handle_slash_command
from pydantic_ai.models.openai import OpenAIModel

pytestmark = pytest.mark.skipif(not imports_successful(), reason='install cli extras to run cli tests')
Expand Down Expand Up @@ -56,6 +57,16 @@ def _create_test_module(**namespace: Any) -> None:
del sys.modules['test_module']


@pytest.fixture
def emtpy_last_conversation_path():
    """Remove any stored last-conversation file and return its path."""
    # NOTE(review): the fixture name has a typo ('emtpy' -> 'empty'); kept as-is
    # because tests reference the fixture by this name.
    conversation_file = PYDANTIC_AI_HOME / LAST_CONVERSATION_FILENAME
    # `missing_ok=True` makes the unlink a no-op when the file doesn't exist.
    conversation_file.unlink(missing_ok=True)
    return conversation_file


def test_agent_flag(
capfd: CaptureFixture[str],
mocker: MockerFixture,
Expand Down Expand Up @@ -163,6 +174,51 @@ def test_cli_prompt(capfd: CaptureFixture[str], env: TestEnv):
assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# result', '', 'py', 'x = 1', '/py'])


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    emtpy_last_conversation_path: Path,
):
    """Running with -c/--continue stores the conversation and resumes it on the next run."""
    env.set('OPENAI_API_KEY', 'test')
    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        # First run: starts fresh and stores the conversation to disk.
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert emtpy_last_conversation_path.exists()
        content = emtpy_last_conversation_path.read_text()
        assert content

        # Second run with the same flag: resumes from the stored history.
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot([IsStr(), '# world'])
        assert emtpy_last_conversation_path.exists()
        # verify the stored conversation grew: the second run loaded the previous
        # messages and re-stored the file with the new exchange included
        assert len(emtpy_last_conversation_path.read_text()) > len(content)


@pytest.mark.parametrize('args', [['hello', '-c'], ['hello', '--continue']])
def test_cli_continue_last_conversation_corrupted_file(
    args: list[str],
    capfd: CaptureFixture[str],
    env: TestEnv,
    emtpy_last_conversation_path: Path,
):
    """A corrupted last-conversation file prints a warning and starts a fresh conversation."""
    env.set('OPENAI_API_KEY', 'test')
    # Seed the conversation file with content that is not valid JSON.
    emtpy_last_conversation_path.write_text('not a valid json')
    with cli_agent.override(model=TestModel(custom_output_text='# world')):
        assert cli(args) == 0
        assert capfd.readouterr().out.splitlines() == snapshot(
            [
                IsStr(),
                'Error loading last conversation, it is corrupted or invalid.',
                'Starting a new conversation.',
                '# world',
            ]
        )
        # The run still stores the new conversation, replacing the corrupted file.
        assert emtpy_last_conversation_path.exists()
        assert emtpy_last_conversation_path.read_text()


def test_chat(capfd: CaptureFixture[str], mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
with create_pipe_input() as inp:
Expand Down Expand Up @@ -228,21 +284,21 @@ def test_code_theme_unset(mocker: MockerFixture, env: TestEnv):
env.set('OPENAI_API_KEY', 'test')
mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
cli([])
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai')
mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai', history=None)


def test_code_theme_light(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=light` maps to the pygments 'default' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
    cli(['--code-theme=light'])
    # `cli` forwards `history=None` when `--continue` is not passed; the stale
    # pre-change assertion (without `history=None`) is removed — it was a diff artifact.
    mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'default', 'pai', history=None)


def test_code_theme_dark(mocker: MockerFixture, env: TestEnv):
    """`--code-theme=dark` maps to the pygments 'monokai' theme."""
    env.set('OPENAI_API_KEY', 'test')
    mock_run_chat = mocker.patch('pydantic_ai._cli.run_chat')
    cli(['--code-theme=dark'])
    # `cli` forwards `history=None` when `--continue` is not passed; the stale
    # pre-change assertion (without `history=None`) is removed — it was a diff artifact.
    mock_run_chat.assert_awaited_once_with(True, IsInstance(Agent), IsInstance(Console), 'monokai', 'pai', history=None)


def test_agent_to_cli_sync(mocker: MockerFixture, env: TestEnv):
Expand Down
Loading