Skip to content

Commit 8c8d174

Browse files
committed
refactor(init): new prompt parser
Signed-off-by: Tomas Slusny <[email protected]>
1 parent 7b15d03 commit 8c8d174

File tree

2 files changed

+207
-85
lines changed

2 files changed

+207
-85
lines changed

lua/CopilotChat/init.lua

Lines changed: 55 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@ local client = require('CopilotChat.client')
55
local constants = require('CopilotChat.constants')
66
local notify = require('CopilotChat.notify')
77
local utils = require('CopilotChat.utils')
8+
local prompts = require('CopilotChat.prompts')
89

9-
local WORD = '([^%s:]+)'
10-
local WORD_NO_INPUT = '([^%s]+)'
11-
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
12-
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'
1310
local BLOCK_OUTPUT_FORMAT = '```%s\n%s\n```'
1411

1512
---@class CopilotChat
@@ -315,10 +312,11 @@ function M.resolve_functions(prompt, config)
315312
tools[tool.name] = tool
316313
end
317314

315+
local refs = prompts.parse(prompt)
316+
local found_tools = utils.to_table(config.tools)
318317
local enabled_tools = {}
319318
local resolved_resources = {}
320319
local resolved_tools = {}
321-
local matches = utils.to_table(config.tools)
322320
local tool_calls = {}
323321
for _, message in ipairs(M.chat.messages) do
324322
if message.tool_calls then
@@ -328,54 +326,33 @@ function M.resolve_functions(prompt, config)
328326
end
329327
end
330328

331-
-- Check for @tool pattern to find enabled tools
332-
prompt = prompt:gsub('@' .. WORD, function(match)
333-
for name, tool in pairs(M.config.functions) do
334-
if name == match or tool.group == match then
335-
table.insert(matches, match)
336-
return ''
329+
-- Find enabled tools from @ references
330+
for _, ref in ipairs(refs) do
331+
if ref.type == 'function_reference' then
332+
for name, tool in pairs(M.config.functions) do
333+
if name == ref.value or tool.group == ref.value then
334+
table.insert(found_tools, ref.value)
335+
end
337336
end
338337
end
339-
return '@' .. match
340-
end)
341-
for _, match in ipairs(matches) do
338+
end
339+
340+
-- Convert tool names to tool objects
341+
for _, match in ipairs(found_tools) do
342342
for name, tool in pairs(M.config.functions) do
343343
if name == match or tool.group == match then
344344
table.insert(enabled_tools, tools[name])
345345
end
346346
end
347347
end
348348

349-
local matches = utils.ordered_map()
350-
351-
-- Check for #word:`input` pattern
352-
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do
353-
local pattern = string.format('#%s:`%s`', word, input)
354-
matches:set(pattern, {
355-
word = word,
356-
input = input,
357-
})
358-
end
359-
360-
-- Check for #word:input pattern
361-
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do
362-
local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input)
363-
matches:set(pattern, {
364-
word = word,
365-
input = input,
366-
})
367-
end
368-
369-
-- Check for ##word:input pattern
370-
for word in prompt:gmatch('##' .. WORD_NO_INPUT) do
371-
local pattern = string.format('##%s', word)
372-
matches:set(pattern, {
373-
word = word,
374-
})
375-
end
349+
prompt = prompts.replace(prompt, refs, function(ref)
350+
if ref.type ~= 'function_call' then
351+
return
352+
end
376353

377-
-- Resolve each function reference
378-
local function expand_function(name, input)
354+
local name = ref.value
355+
local input = ref.input
379356
notify.publish(notify.STATUS, 'Running function: ' .. name)
380357

381358
local tool_id = nil
@@ -448,17 +425,7 @@ function M.resolve_functions(prompt, config)
448425
end
449426

450427
return result
451-
end
452-
453-
-- Resolve and process all tools
454-
for _, pattern in ipairs(matches:keys()) do
455-
if not utils.empty(pattern) then
456-
local match = matches:get(pattern)
457-
local out = expand_function(match.word, match.input) or pattern
458-
out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub
459-
prompt = prompt:gsub(vim.pesc(pattern), out, 1)
460-
end
461-
end
428+
end)
462429

463430
return enabled_tools, resolved_resources, resolved_tools, prompt
464431
end
@@ -479,45 +446,45 @@ function M.resolve_prompt(prompt, config)
479446
local depth = 0
480447
local MAX_DEPTH = 10
481448

482-
local function resolve(inner_config, inner_prompt)
483-
if depth >= MAX_DEPTH then
449+
local function resolve_prompt_template(inner_config, inner_prompt)
450+
if depth >= MAX_DEPTH or not inner_prompt then
484451
return inner_config, inner_prompt
485452
end
453+
486454
depth = depth + 1
487455

488-
inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match)
489-
local p = prompts_to_use[match]
490-
if p then
491-
local resolved_config, resolved_prompt = resolve(p, p.prompt or '')
492-
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
493-
return resolved_prompt
456+
for _, ref in ipairs(prompts.parse(inner_prompt)) do
457+
if ref.type == 'prompt' then
458+
vim.print(ref)
459+
local template = prompts_to_use[ref.value]
460+
if template then
461+
local resolved_config, resolved_prompt = resolve_prompt_template(template, template.prompt)
462+
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
463+
if resolved_prompt then
464+
inner_prompt = inner_prompt:sub(1, ref.start_pos - 1)
465+
.. resolved_prompt
466+
.. inner_prompt:sub(ref.end_pos + 1)
467+
end
468+
end
494469
end
495-
496-
return '/' .. match
497-
end)
470+
end
498471

499472
depth = depth - 1
500473
return inner_config, inner_prompt
501474
end
502475

503-
local function resolve_system_prompt(system_prompt)
504-
if type(system_prompt) == 'function' then
505-
local ok, result = pcall(system_prompt)
506-
if not ok then
507-
log.warn('Failed to resolve system prompt function: ' .. result)
508-
return nil
509-
end
510-
return result
511-
end
512-
513-
return system_prompt
514-
end
515-
516476
config = vim.tbl_deep_extend('force', M.config, config or {})
517-
config, prompt = resolve(config, prompt or '')
477+
config, prompt = resolve_prompt_template(config, prompt)
478+
prompt = prompt or ''
518479

519480
if config.system_prompt then
520-
config.system_prompt = resolve_system_prompt(config.system_prompt)
481+
if type(config.system_prompt) == 'function' then
482+
---@diagnostic disable-next-line: param-type-mismatch
483+
local ok, result = pcall(config.system_prompt)
484+
if ok then
485+
config.system_prompt = result
486+
end
487+
end
521488

522489
if M.config.prompts[config.system_prompt] then
523490
-- Name references are good for making system prompt auto sticky
@@ -547,13 +514,16 @@ function M.resolve_model(prompt, config)
547514
return model.id
548515
end, list_models())
549516

517+
local refs = prompts.parse(prompt)
550518
local selected_model = config.model or ''
551-
prompt = prompt:gsub('%$' .. WORD, function(match)
552-
if vim.tbl_contains(models, match) then
553-
selected_model = match
554-
return ''
519+
520+
prompt = prompts.replace(prompt, refs, function(ref)
521+
if ref.type == 'model' then
522+
if vim.tbl_contains(models, ref.value) then
523+
selected_model = ref.value
524+
return ''
525+
end
555526
end
556-
return '$' .. match
557527
end)
558528

559529
return selected_model, prompt

lua/CopilotChat/prompts.lua

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
local M = {}
2+
3+
local WORD = '([^%s:]+)'
4+
local WORD_NO_INPUT = '([^%s]+)'
5+
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
6+
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'
7+
8+
---@class CopilotChat.prompts.Reference
9+
---@field type 'model'|'function'|'function_call'|'resource'|'sticky'|'prompt'
10+
---@field value string
11+
---@field input? string
12+
---@field start_pos integer
13+
---@field end_pos integer
14+
15+
--- Parse all references from a prompt string, tracking positions.
16+
---@param prompt string
17+
---@return CopilotChat.prompts.Reference[] refs
18+
function M.parse(prompt)
19+
local refs = {}
20+
21+
-- $model
22+
for s, value, e in prompt:gmatch('()%$' .. WORD .. '()') do
23+
table.insert(refs, {
24+
type = 'model',
25+
value = value,
26+
start_pos = s,
27+
end_pos = e - 1,
28+
})
29+
end
30+
31+
-- @function
32+
for s, value, e in prompt:gmatch('()@' .. WORD .. '()') do
33+
table.insert(refs, {
34+
type = 'function_reference',
35+
value = value,
36+
start_pos = s,
37+
end_pos = e - 1,
38+
})
39+
end
40+
41+
-- #function_call
42+
local function function_call_matches(str)
43+
local matches = {}
44+
-- #function_call:`input` (quoted)
45+
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_QUOTED .. '()') do
46+
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
47+
end
48+
-- #function_call:input (unquoted)
49+
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_UNQUOTED .. '()') do
50+
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
51+
end
52+
-- #function_call (no input)
53+
for s, value, e in str:gmatch('()#' .. WORD_NO_INPUT .. '()') do
54+
table.insert(matches, { s = s, e = e - 1, value = value, input = nil })
55+
end
56+
return matches
57+
end
58+
for _, m in ipairs(function_call_matches(prompt)) do
59+
table.insert(refs, {
60+
type = 'function_call',
61+
value = m.value,
62+
input = m.input or nil,
63+
start_pos = m.s,
64+
end_pos = m.e,
65+
})
66+
end
67+
68+
-- ##resource
69+
for s, value, e in prompt:gmatch('()##' .. WORD_NO_INPUT .. '()') do
70+
table.insert(refs, {
71+
type = 'resource',
72+
value = value,
73+
start_pos = s,
74+
end_pos = e - 1,
75+
})
76+
end
77+
78+
-- > sticky
79+
local function sticky_matches(str)
80+
local matches = {}
81+
-- > sticky (newline)
82+
for s, value, e in str:gmatch('()\n> ([^\n]+)()') do
83+
table.insert(matches, { s = s + 1, e = e - 1, value = value })
84+
end
85+
-- > sticky (start of string)
86+
for s, value, e in str:gmatch('()^> ([^\n]+)()') do
87+
table.insert(matches, { s = s, e = e - 1, value = value })
88+
end
89+
return matches
90+
end
91+
for _, m in ipairs(sticky_matches(prompt)) do
92+
table.insert(refs, {
93+
type = 'sticky',
94+
value = m.value,
95+
start_pos = m.s,
96+
end_pos = m.e,
97+
})
98+
end
99+
100+
-- /prompt
101+
for s, value, e in prompt:gmatch('()/' .. WORD_NO_INPUT .. '()') do
102+
table.insert(refs, {
103+
type = 'prompt',
104+
value = value,
105+
start_pos = s,
106+
end_pos = e - 1,
107+
})
108+
end
109+
110+
local keep = {}
111+
for i, ref in ipairs(refs) do
112+
local contained = false
113+
for j, other in ipairs(refs) do
114+
if i ~= j then
115+
-- Strictly contained
116+
if other.type ~= 'sticky' and ref.start_pos > other.start_pos and ref.end_pos < other.end_pos then
117+
contained = true
118+
break
119+
end
120+
-- Exact match, only keep the first occurrence
121+
if ref.start_pos == other.start_pos and ref.end_pos == other.end_pos and j < i then
122+
contained = true
123+
break
124+
end
125+
end
126+
end
127+
if not contained then
128+
table.insert(keep, ref)
129+
end
130+
end
131+
132+
return keep
133+
end
134+
135+
--- Replace references in the prompt using positions (descending order).
136+
---@param prompt string
137+
---@param refs CopilotChat.prompts.Reference[]
138+
---@param resolver fun(ref: CopilotChat.prompts.Reference): string?
139+
function M.replace(prompt, refs, resolver)
140+
table.sort(refs, function(a, b)
141+
return a.start_pos > b.start_pos
142+
end)
143+
for _, ref in ipairs(refs) do
144+
local output = resolver(ref)
145+
if output then
146+
prompt = prompt:sub(1, ref.start_pos - 1) .. output .. prompt:sub(ref.end_pos + 1)
147+
end
148+
end
149+
return prompt
150+
end
151+
152+
return M

0 commit comments

Comments
 (0)