Skip to content

Commit b2bffa7

Browse files
committed
refactor(init): new prompt parser
Signed-off-by: Tomas Slusny <[email protected]>
1 parent 76cc416 commit b2bffa7

File tree

2 files changed

+208
-82
lines changed

2 files changed

+208
-82
lines changed

lua/CopilotChat/init.lua

Lines changed: 56 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@ local client = require('CopilotChat.client')
55
local constants = require('CopilotChat.constants')
66
local notify = require('CopilotChat.notify')
77
local utils = require('CopilotChat.utils')
8+
local prompts = require('CopilotChat.prompts')
89

9-
local WORD = '([^%s:]+)'
10-
local WORD_NO_INPUT = '([^%s]+)'
11-
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
12-
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'
1310
local BLOCK_OUTPUT_FORMAT = '```%s\n%s\n```'
1411

1512
---@class CopilotChat
@@ -315,10 +312,11 @@ function M.resolve_functions(prompt, config)
315312
tools[tool.name] = tool
316313
end
317314

315+
local refs = prompts.parse(prompt)
316+
local found_tools = utils.to_table(config.tools)
318317
local enabled_tools = {}
319318
local resolved_resources = {}
320319
local resolved_tools = {}
321-
local matches = utils.to_table(config.tools)
322320
local tool_calls = {}
323321
for _, message in ipairs(M.chat.messages) do
324322
if message.tool_calls then
@@ -328,53 +326,30 @@ function M.resolve_functions(prompt, config)
328326
end
329327
end
330328

331-
-- Check for @tool pattern to find enabled tools
332-
prompt = prompt:gsub('@' .. WORD, function(match)
333-
for name, tool in pairs(M.config.functions) do
334-
if name == match or tool.group == match then
335-
table.insert(matches, match)
336-
return ''
329+
-- Find enabled tools from @ references
330+
for _, ref in ipairs(refs) do
331+
if ref.type == 'function_reference' then
332+
for name, tool in pairs(M.config.functions) do
333+
if name == ref.value or tool.group == ref.value then
334+
table.insert(found_tools, ref.value)
335+
end
337336
end
338337
end
339-
return '@' .. match
340-
end)
341-
for _, match in ipairs(matches) do
338+
end
339+
340+
-- Convert tool names to tool objects
341+
for _, match in ipairs(found_tools) do
342342
for name, tool in pairs(M.config.functions) do
343343
if name == match or tool.group == match then
344344
table.insert(enabled_tools, tools[name])
345345
end
346346
end
347347
end
348348

349-
local matches = utils.ordered_map()
350-
351-
-- Check for #word:`input` pattern
352-
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_QUOTED) do
353-
local pattern = string.format('#%s:`%s`', word, input)
354-
matches:set(pattern, {
355-
word = word,
356-
input = input,
357-
})
358-
end
359-
360-
-- Check for #word:input pattern
361-
for word, input in prompt:gmatch('#' .. WORD_WITH_INPUT_UNQUOTED) do
362-
local pattern = utils.empty(input) and string.format('#%s', word) or string.format('#%s:%s', word, input)
363-
matches:set(pattern, {
364-
word = word,
365-
input = input,
366-
})
367-
end
368-
369-
-- Check for ##word:input pattern
370-
for word in prompt:gmatch('##' .. WORD_NO_INPUT) do
371-
local pattern = string.format('##%s', word)
372-
matches:set(pattern, {
373-
word = word,
374-
})
375-
end
376-
377-
-- Resolve each function reference
349+
-- Helper to resolve function calls
350+
---@param name string
351+
---@param input string
352+
---@return string|nil
378353
local function expand_function(name, input)
379354
notify.publish(notify.STATUS, 'Running function: ' .. name)
380355

@@ -450,15 +425,11 @@ function M.resolve_functions(prompt, config)
450425
return result
451426
end
452427

453-
-- Resolve and process all tools
454-
for _, pattern in ipairs(matches:keys()) do
455-
if not utils.empty(pattern) then
456-
local match = matches:get(pattern)
457-
local out = expand_function(match.word, match.input) or pattern
458-
out = out:gsub('%%', '%%%%') -- Escape percent signs for gsub
459-
prompt = prompt:gsub(vim.pesc(pattern), out, 1)
428+
prompt = prompts.replace(prompt, refs, function(ref)
429+
if ref.type == 'function_call' then
430+
return expand_function(ref.value, ref.input)
460431
end
461-
end
432+
end)
462433

463434
return enabled_tools, resolved_resources, resolved_tools, prompt
464435
end
@@ -479,45 +450,45 @@ function M.resolve_prompt(prompt, config)
479450
local depth = 0
480451
local MAX_DEPTH = 10
481452

482-
local function resolve(inner_config, inner_prompt)
483-
if depth >= MAX_DEPTH then
453+
local function resolve_prompt_template(inner_config, inner_prompt)
454+
if depth >= MAX_DEPTH or not inner_prompt then
484455
return inner_config, inner_prompt
485456
end
457+
486458
depth = depth + 1
487459

488-
inner_prompt = string.gsub(inner_prompt, '/' .. WORD, function(match)
489-
local p = prompts_to_use[match]
490-
if p then
491-
local resolved_config, resolved_prompt = resolve(p, p.prompt or '')
492-
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
493-
return resolved_prompt
460+
for _, ref in ipairs(prompts.parse(inner_prompt)) do
461+
if ref.type == 'prompt' then
462+
vim.print(ref)
463+
local template = prompts_to_use[ref.value]
464+
if template then
465+
local resolved_config, resolved_prompt = resolve_prompt_template(template, template.prompt)
466+
inner_config = vim.tbl_deep_extend('force', inner_config, resolved_config)
467+
if resolved_prompt then
468+
inner_prompt = inner_prompt:sub(1, ref.start_pos - 1)
469+
.. resolved_prompt
470+
.. inner_prompt:sub(ref.end_pos + 1)
471+
end
472+
end
494473
end
495-
496-
return '/' .. match
497-
end)
474+
end
498475

499476
depth = depth - 1
500477
return inner_config, inner_prompt
501478
end
502479

503-
local function resolve_system_prompt(system_prompt)
504-
if type(system_prompt) == 'function' then
505-
local ok, result = pcall(system_prompt)
506-
if not ok then
507-
log.warn('Failed to resolve system prompt function: ' .. result)
508-
return nil
509-
end
510-
return result
511-
end
512-
513-
return system_prompt
514-
end
515-
516480
config = vim.tbl_deep_extend('force', M.config, config or {})
517-
config, prompt = resolve(config, prompt or '')
481+
config, prompt = resolve_prompt_template(config, prompt)
482+
prompt = prompt or ''
518483

519484
if config.system_prompt then
520-
config.system_prompt = resolve_system_prompt(config.system_prompt)
485+
if type(config.system_prompt) == 'function' then
486+
---@diagnostic disable-next-line: param-type-mismatch
487+
local ok, result = pcall(config.system_prompt)
488+
if ok then
489+
config.system_prompt = result
490+
end
491+
end
521492

522493
if M.config.prompts[config.system_prompt] then
523494
-- Name references are good for making system prompt auto sticky
@@ -547,13 +518,16 @@ function M.resolve_model(prompt, config)
547518
return model.id
548519
end, list_models())
549520

521+
local refs = prompts.parse(prompt)
550522
local selected_model = config.model or ''
551-
prompt = prompt:gsub('%$' .. WORD, function(match)
552-
if vim.tbl_contains(models, match) then
553-
selected_model = match
554-
return ''
523+
524+
prompt = prompts.replace(prompt, refs, function(ref)
525+
if ref.type == 'model' then
526+
if vim.tbl_contains(models, ref.value) then
527+
selected_model = ref.value
528+
return ''
529+
end
555530
end
556-
return '$' .. match
557531
end)
558532

559533
return selected_model, prompt

lua/CopilotChat/prompts.lua

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
local M = {}
2+
3+
local WORD = '([^%s:]+)'
4+
local WORD_NO_INPUT = '([^%s]+)'
5+
local WORD_WITH_INPUT_QUOTED = WORD .. ':`([^`]+)`'
6+
local WORD_WITH_INPUT_UNQUOTED = WORD .. ':?([^%s`]*)'
7+
8+
---@class CopilotChat.prompts.Reference
9+
---@field type 'model'|'function'|'function_call'|'resource'|'sticky'|'prompt'
10+
---@field value string
11+
---@field input? string
12+
---@field start_pos integer
13+
---@field end_pos integer
14+
15+
--- Parse all references from a prompt string, tracking positions.
16+
---@param prompt string
17+
---@return CopilotChat.prompts.Reference[] refs
18+
function M.parse(prompt)
19+
local refs = {}
20+
21+
-- $model
22+
for s, value, e in prompt:gmatch('()%$' .. WORD .. '()') do
23+
table.insert(refs, {
24+
type = 'model',
25+
value = value,
26+
start_pos = s,
27+
end_pos = e - 1,
28+
})
29+
end
30+
31+
-- @function
32+
for s, value, e in prompt:gmatch('()@' .. WORD .. '()') do
33+
table.insert(refs, {
34+
type = 'function_reference',
35+
value = value,
36+
start_pos = s,
37+
end_pos = e - 1,
38+
})
39+
end
40+
41+
-- #function_call
42+
local function function_call_matches(str)
43+
local matches = {}
44+
-- #function_call:`input` (quoted)
45+
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_QUOTED .. '()') do
46+
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
47+
end
48+
-- #function_call:input (unquoted)
49+
for s, value, input, e in str:gmatch('()#' .. WORD_WITH_INPUT_UNQUOTED .. '()') do
50+
table.insert(matches, { s = s, e = e - 1, value = value, input = input })
51+
end
52+
-- #function_call (no input)
53+
for s, value, e in str:gmatch('()#' .. WORD_NO_INPUT .. '()') do
54+
table.insert(matches, { s = s, e = e - 1, value = value, input = nil })
55+
end
56+
return matches
57+
end
58+
for _, m in ipairs(function_call_matches(prompt)) do
59+
table.insert(refs, {
60+
type = 'function_call',
61+
value = m.value,
62+
input = m.input or nil,
63+
start_pos = m.s,
64+
end_pos = m.e,
65+
})
66+
end
67+
68+
-- ##resource
69+
for s, value, e in prompt:gmatch('()##' .. WORD_NO_INPUT .. '()') do
70+
table.insert(refs, {
71+
type = 'resource',
72+
value = value,
73+
start_pos = s,
74+
end_pos = e - 1,
75+
})
76+
end
77+
78+
-- > sticky
79+
local function sticky_matches(str)
80+
local matches = {}
81+
-- > sticky (newline)
82+
for s, value, e in str:gmatch('()\n> ([^\n]+)()') do
83+
table.insert(matches, { s = s + 1, e = e - 1, value = value })
84+
end
85+
-- > sticky (start of string)
86+
for s, value, e in str:gmatch('()^> ([^\n]+)()') do
87+
table.insert(matches, { s = s, e = e - 1, value = value })
88+
end
89+
return matches
90+
end
91+
for _, m in ipairs(sticky_matches(prompt)) do
92+
table.insert(refs, {
93+
type = 'sticky',
94+
value = m.value,
95+
start_pos = m.s,
96+
end_pos = m.e,
97+
})
98+
end
99+
100+
-- /prompt
101+
for s, value, e in prompt:gmatch('()/' .. WORD_NO_INPUT .. '()') do
102+
table.insert(refs, {
103+
type = 'prompt',
104+
value = value,
105+
start_pos = s,
106+
end_pos = e - 1,
107+
})
108+
end
109+
110+
local keep = {}
111+
for i, ref in ipairs(refs) do
112+
local contained = false
113+
for j, other in ipairs(refs) do
114+
if i ~= j then
115+
-- Strictly contained
116+
if other.type ~= 'sticky' and ref.start_pos > other.start_pos and ref.end_pos < other.end_pos then
117+
contained = true
118+
break
119+
end
120+
-- Exact match, only keep the first occurrence
121+
if ref.start_pos == other.start_pos and ref.end_pos == other.end_pos and j < i then
122+
contained = true
123+
break
124+
end
125+
end
126+
end
127+
if not contained then
128+
table.insert(keep, ref)
129+
end
130+
end
131+
132+
return keep
133+
end
134+
135+
--- Replace references in the prompt using positions (descending order).
136+
---@param prompt string
137+
---@param refs CopilotChat.prompts.Reference[]
138+
---@param resolver fun(ref: CopilotChat.prompts.Reference): string?
139+
function M.replace(prompt, refs, resolver)
140+
table.sort(refs, function(a, b)
141+
return a.start_pos > b.start_pos
142+
end)
143+
for _, ref in ipairs(refs) do
144+
local output = resolver(ref)
145+
if output then
146+
prompt = prompt:sub(1, ref.start_pos - 1) .. output .. prompt:sub(ref.end_pos + 1)
147+
end
148+
end
149+
return prompt
150+
end
151+
152+
return M

0 commit comments

Comments
 (0)