12
12
from pathlib import Path
13
13
from typing import Any , cast
14
14
15
+ from pydantic import ValidationError
15
16
from typing_inspection .introspection import get_literal_values
16
17
17
18
from . import __version__
18
19
from ._run_context import AgentDepsT
19
20
from .agent import Agent
20
21
from .exceptions import UserError
21
- from .messages import ModelMessage
22
+ from .messages import ModelMessage , ModelMessagesTypeAdapter
22
23
from .models import KnownModelName , infer_model
23
24
from .output import OutputDataT
24
25
53
54
"""
54
55
55
56
PROMPT_HISTORY_FILENAME = 'prompt-history.txt'
57
+ LAST_CONVERSATION_FILENAME = 'last-conversation.json'
56
58
57
59
58
60
class SimpleCodeBlock (CodeBlock ):
@@ -146,6 +148,13 @@ def cli( # noqa: C901
146
148
help = 'Which colors to use for code, can be "dark", "light" or any theme from pygments.org/styles/. Defaults to "dark" which works well on dark terminals.' ,
147
149
default = 'dark' ,
148
150
)
151
+ parser .add_argument (
152
+ '-c' ,
153
+ '--continue' ,
154
+ dest = 'continue_' ,
155
+ action = 'store_true' ,
156
+ help = 'Continue last conversation, if any, instead of starting a new one.' ,
157
+ )
149
158
parser .add_argument ('--no-stream' , action = 'store_true' , help = 'Disable streaming from the model' )
150
159
parser .add_argument ('--version' , action = 'store_true' , help = 'Show version and exit' )
151
160
@@ -205,19 +214,42 @@ def cli( # noqa: C901
205
214
else :
206
215
code_theme = args .code_theme # pragma: no cover
207
216
217
+ try :
218
+ history = load_last_conversation () if args .continue_ else None
219
+ except ValidationError :
220
+ console .print (
221
+ '[red]Error loading last conversation, it is corrupted or invalid. Starting a new conversation.[/red]'
222
+ )
223
+ history = None
224
+
208
225
if prompt := cast (str , args .prompt ):
209
226
try :
210
- asyncio .run (ask_agent (agent , prompt , stream , console , code_theme ))
227
+ asyncio .run (ask_agent (agent , prompt , stream , console , code_theme , messages = history ))
211
228
except KeyboardInterrupt :
212
229
pass
213
230
return 0
214
231
215
232
try :
216
- return asyncio .run (run_chat (stream , agent , console , code_theme , prog_name ))
233
+ return asyncio .run (run_chat (stream , agent , console , code_theme , prog_name , history = history ))
217
234
except KeyboardInterrupt : # pragma: no cover
218
235
return 0
219
236
220
237
238
def store_last_conversation(messages: list[ModelMessage], config_dir: Path | None = None) -> None:
    """Persist the conversation history as JSON so a later run can restore it.

    Args:
        messages: Full message history of the conversation to store.
        config_dir: Directory to write the file into; defaults to ``PYDANTIC_AI_HOME``.
    """
    target = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME
    # Make sure the config directory exists before writing.
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_bytes(ModelMessagesTypeAdapter.dump_json(messages))
243
+
244
def load_last_conversation(config_dir: Path | None = None) -> list[ModelMessage] | None:
    """Load the previously stored conversation, if any.

    Args:
        config_dir: Directory the conversation file was stored in; defaults to ``PYDANTIC_AI_HOME``.

    Returns:
        The deserialized message history, or ``None`` if no conversation has been stored.

    Raises:
        ValidationError: If the stored file is corrupted or does not match the message schema.
    """
    last_conversation_path = (config_dir or PYDANTIC_AI_HOME) / LAST_CONVERSATION_FILENAME

    # EAFP instead of an `exists()` pre-check (which is race-prone), and read raw
    # bytes: the file is written as UTF-8 JSON by `store_last_conversation`, while
    # `read_text()` would decode with the locale encoding on some platforms.
    try:
        content = last_conversation_path.read_bytes()
    except FileNotFoundError:
        return None

    return ModelMessagesTypeAdapter.validate_json(content)
252
+
221
253
async def run_chat (
222
254
stream : bool ,
223
255
agent : Agent [AgentDepsT , OutputDataT ],
@@ -226,14 +258,15 @@ async def run_chat(
226
258
prog_name : str ,
227
259
config_dir : Path | None = None ,
228
260
deps : AgentDepsT = None ,
261
+ history : list [ModelMessage ] | None = None ,
229
262
) -> int :
230
263
prompt_history_path = (config_dir or PYDANTIC_AI_HOME ) / PROMPT_HISTORY_FILENAME
231
264
prompt_history_path .parent .mkdir (parents = True , exist_ok = True )
232
265
prompt_history_path .touch (exist_ok = True )
233
266
session : PromptSession [Any ] = PromptSession (history = FileHistory (str (prompt_history_path )))
234
267
235
268
multiline = False
236
- messages : list [ModelMessage ] = []
269
+ messages : list [ModelMessage ] = history or []
237
270
238
271
while True :
239
272
try :
@@ -252,7 +285,7 @@ async def run_chat(
252
285
return exit_value
253
286
else :
254
287
try :
255
- messages = await ask_agent (agent , text , stream , console , code_theme , deps , messages )
288
+ messages = await ask_agent (agent , text , stream , console , code_theme , deps , messages , config_dir )
256
289
except CancelledError : # pragma: no cover
257
290
console .print ('[dim]Interrupted[/dim]' )
258
291
except Exception as e : # pragma: no cover
@@ -270,6 +303,7 @@ async def ask_agent(
270
303
code_theme : str ,
271
304
deps : AgentDepsT = None ,
272
305
messages : list [ModelMessage ] | None = None ,
306
+ config_dir : Path | None = None ,
273
307
) -> list [ModelMessage ]:
274
308
status = Status ('[dim]Working on it…[/dim]' , console = console )
275
309
@@ -278,7 +312,10 @@ async def ask_agent(
278
312
result = await agent .run (prompt , message_history = messages , deps = deps )
279
313
content = str (result .output )
280
314
console .print (Markdown (content , code_theme = code_theme ))
281
- return result .all_messages ()
315
+ result_messages = result .all_messages ()
316
+ store_last_conversation (result_messages , config_dir )
317
+
318
+ return result_messages
282
319
283
320
with status , ExitStack () as stack :
284
321
async with agent .iter (prompt , message_history = messages , deps = deps ) as agent_run :
@@ -293,7 +330,10 @@ async def ask_agent(
293
330
live .update (Markdown (str (content ), code_theme = code_theme ))
294
331
295
332
assert agent_run .result is not None
296
- return agent_run .result .all_messages ()
333
+ result_messages = agent_run .result .all_messages ()
334
+ store_last_conversation (result_messages , config_dir )
335
+
336
+ return result_messages
297
337
298
338
299
339
class CustomAutoSuggest (AutoSuggestFromHistory ):
0 commit comments