Skip to content

Commit 7688cfe

Browse files
committed
clai: Add ability to continue last conversation
1 parent 01c550c commit 7688cfe

File tree

2 files changed

+49
-8
lines changed

2 files changed

+49
-8
lines changed

clai/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Either way, running `clai` will start an interactive session where you can chat
5353
## Help
5454

5555
```
56-
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt]
56+
usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [-c] [--no-stream] [--version] [prompt]
5757
5858
Pydantic AI CLI v...
5959
@@ -74,6 +74,7 @@ options:
7474
-l, --list-models List all available models and exit
7575
-t [CODE_THEME], --code-theme [CODE_THEME]
7676
Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.
77+
-c, --continue Continue last conversation, if any, instead of starting a new one.
7778
--no-stream Disable streaming from the model
7879
--version Show version and exit
7980
```

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
from pathlib import Path
1313
from typing import Any, cast
1414

15+
from pydantic import ValidationError
1516
from typing_inspection.introspection import get_literal_values
1617

1718
from . import __version__
1819
from ._run_context import AgentDepsT
1920
from .agent import Agent
2021
from .exceptions import UserError
21-
from .messages import ModelMessage
22+
from .messages import ModelMessage, ModelMessagesTypeAdapter
2223
from .models import KnownModelName, infer_model
2324
from .output import OutputDataT
2425

@@ -53,6 +54,7 @@
5354
"""
5455

5556
PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
57+
LAST_CONVERSATION_FILENAME = 'last-conversation.json'
5658

5759

5860
class SimpleCodeBlock(CodeBlock):
@@ -146,6 +148,13 @@ def cli( # noqa: C901
146148
help='Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.',
147149
default='dark',
148150
)
151+
parser.add_argument(
152+
'-c',
153+
'--continue',
154+
dest='continue_',
155+
action='store_true',
156+
help='Continue last conversation, if any, instead of starting a new one.',
157+
)
149158
parser.add_argument('--no-stream', action='store_true', help='Disable streaming from the model')
150159
parser.add_argument('--version', action='store_true', help='Show version and exit')
151160

@@ -205,19 +214,42 @@ def cli( # noqa: C901
205214
else:
206215
code_theme = args.code_theme # pragma: no cover
207216

217+
try:
218+
history = load_last_conversation() if args.continue_ else None
219+
except ValidationError:
220+
console.print(
221+
'[red]Error loading last conversation, it is corrupted or invalid. Starting a new conversation.[/red]'
222+
)
223+
history = None
224+
208225
if prompt := cast(str, args.prompt):
209226
try:
210-
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme))
227+
asyncio.run(ask_agent(agent, prompt, stream, console, code_theme, messages=history))
211228
except KeyboardInterrupt:
212229
pass
213230
return 0
214231

215232
try:
216-
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name))
233+
return asyncio.run(run_chat(stream, agent, console, code_theme, prog_name, history=history))
217234
except KeyboardInterrupt: # pragma: no cover
218235
return 0
219236

220237

238+
def store_last_conversation(messages: list[ModelMessage], config_dir: Path | None = None) -> None:
    """Persist the given conversation so a later `clai -c` run can resume it.

    Serializes *messages* to JSON with `ModelMessagesTypeAdapter` and writes the
    bytes to `last-conversation.json` under *config_dir* (defaulting to
    `PYDANTIC_AI_HOME`), creating parent directories as needed.
    """
    target_dir = config_dir or PYDANTIC_AI_HOME
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = ModelMessagesTypeAdapter.dump_json(messages)
    (target_dir / LAST_CONVERSATION_FILENAME).write_bytes(serialized)
242+
243+
244+
def load_last_conversation(config_dir: Path | None = None) -> list[ModelMessage] | None:
    """Load the conversation saved by `store_last_conversation`, if any.

    Args:
        config_dir: Directory holding `last-conversation.json`; defaults to
            `PYDANTIC_AI_HOME`.

    Returns:
        The deserialized message history, or `None` if no conversation was
        previously stored.

    Raises:
        pydantic.ValidationError: If the stored file is corrupted or does not
            match the message schema (callers catch this and start fresh).
    """
    last_conversation_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME

    if not last_conversation_path.exists():
        return None

    # Read raw bytes rather than text: store_last_conversation writes UTF-8
    # JSON bytes, and read_text() would decode with the locale's default
    # encoding on some platforms. validate_json accepts bytes directly.
    return ModelMessagesTypeAdapter.validate_json(last_conversation_path.read_bytes())
251+
252+
221253
async def run_chat(
222254
stream: bool,
223255
agent: Agent[AgentDepsT, OutputDataT],
@@ -226,14 +258,15 @@ async def run_chat(
226258
prog_name: str,
227259
config_dir: Path | None = None,
228260
deps: AgentDepsT = None,
261+
history: list[ModelMessage] | None = None,
229262
) -> int:
230263
prompt_history_path = (config_dir or PYDANTIC_AI_HOME) / PROMPT_HISTORY_FILENAME
231264
prompt_history_path.parent.mkdir(parents=True, exist_ok=True)
232265
prompt_history_path.touch(exist_ok=True)
233266
session: PromptSession[Any] = PromptSession(history=FileHistory(str(prompt_history_path)))
234267

235268
multiline = False
236-
messages: list[ModelMessage] = []
269+
messages: list[ModelMessage] = history or []
237270

238271
while True:
239272
try:
@@ -252,7 +285,7 @@ async def run_chat(
252285
return exit_value
253286
else:
254287
try:
255-
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages)
288+
messages = await ask_agent(agent, text, stream, console, code_theme, deps, messages, config_dir)
256289
except CancelledError: # pragma: no cover
257290
console.print('[dim]Interrupted[/dim]')
258291
except Exception as e: # pragma: no cover
@@ -270,6 +303,7 @@ async def ask_agent(
270303
code_theme: str,
271304
deps: AgentDepsT = None,
272305
messages: list[ModelMessage] | None = None,
306+
config_dir: Path | None = None,
273307
) -> list[ModelMessage]:
274308
status = Status('[dim]Working on it…[/dim]', console=console)
275309

@@ -278,7 +312,10 @@ async def ask_agent(
278312
result = await agent.run(prompt, message_history=messages, deps=deps)
279313
content = str(result.output)
280314
console.print(Markdown(content, code_theme=code_theme))
281-
return result.all_messages()
315+
result_messages = result.all_messages()
316+
store_last_conversation(result_messages, config_dir)
317+
318+
return result_messages
282319

283320
with status, ExitStack() as stack:
284321
async with agent.iter(prompt, message_history=messages, deps=deps) as agent_run:
@@ -293,7 +330,10 @@ async def ask_agent(
293330
live.update(Markdown(str(content), code_theme=code_theme))
294331

295332
assert agent_run.result is not None
296-
return agent_run.result.all_messages()
333+
result_messages = agent_run.result.all_messages()
334+
store_last_conversation(result_messages, config_dir)
335+
336+
return result_messages
297337

298338

299339
class CustomAutoSuggest(AutoSuggestFromHistory):

0 commit comments

Comments
 (0)