From 40360fb186645d9365bad4341bec272cfbd9aeec Mon Sep 17 00:00:00 2001
From: Tomas Slusny
Date: Tue, 18 Feb 2025 19:50:22 +0100
Subject: [PATCH] feat: Add github workspace command

Supports resolving GitHub code index workspace data and searching in it.

TODO: Currently this API does not accept the ghu_ GitHub Copilot token,
so I need to use the `gh` CLI instead, which creates hosts.yml.

https://github.blog/engineering/the-technology-behind-githubs-new-code-search/

Signed-off-by: Tomas Slusny
---
 lua/CopilotChat/client.lua           |  18 +++++
 lua/CopilotChat/config/contexts.lua  |   8 +-
 lua/CopilotChat/config/mappings.lua  |   3 +-
 lua/CopilotChat/config/providers.lua | 108 +++++++++++++++++++++++++++
 lua/CopilotChat/context.lua          |  14 ++++
 lua/CopilotChat/init.lua             |   7 +-
 6 files changed, 153 insertions(+), 5 deletions(-)

diff --git a/lua/CopilotChat/client.lua b/lua/CopilotChat/client.lua
index 6c709f93..a0d5d6e1 100644
--- a/lua/CopilotChat/client.lua
+++ b/lua/CopilotChat/client.lua
@@ -809,6 +809,24 @@ function Client:embed(inputs, model)
   return results
 end
 
+--- Search for the given query
+---@param query string: The query to search for
+---@param repository string: The repository to search in
+---@param model string: The model to use for search
+---@return table
+function Client:search(query, repository, model)
+  local models = self:fetch_models()
+
+  local provider_name, search = resolve_provider_function('search', model, models, self.providers)
+  local headers = self:authenticate(provider_name)
+  local ok, response = pcall(search, query, repository, headers)
+  if not ok then
+    log.warn('Failed to search: ', response)
+    return {}
+  end
+  return response
+end
+
 --- Stop the running job
 ---@return boolean
 function Client:stop()
diff --git a/lua/CopilotChat/config/contexts.lua b/lua/CopilotChat/config/contexts.lua
index b137f53c..a6e4d2cf 100644
--- a/lua/CopilotChat/config/contexts.lua
+++ b/lua/CopilotChat/config/contexts.lua
@@ -5,7 +5,7 @@ local utils = require('CopilotChat.utils')
 ---@class CopilotChat.config.context
 ---@field description string?
 ---@field input fun(callback: fun(input: string?), source: CopilotChat.source)?
----@field resolve fun(input: string?, source: CopilotChat.source, prompt: string):table
+---@field resolve fun(input: string?, source: CopilotChat.source, prompt: string, model: string):table
 
 ---@type table
 return {
@@ -173,4 +173,10 @@ return {
       return context.quickfix()
     end,
   },
+  workspace = {
+    description = 'Includes all non-hidden files in the current workspace in chat context.',
+    resolve = function(_, _, prompt, model)
+      return context.workspace(prompt, model)
+    end,
+  },
 }
diff --git a/lua/CopilotChat/config/mappings.lua b/lua/CopilotChat/config/mappings.lua
index c04895b5..841446d3 100644
--- a/lua/CopilotChat/config/mappings.lua
+++ b/lua/CopilotChat/config/mappings.lua
@@ -409,7 +409,8 @@ return {
       async.run(function()
         local embeddings = {}
         if section and not section.answer then
-          embeddings = copilot.resolve_embeddings(section.content, chat.config)
+          local _, selected_model = pcall(copilot.resolve_model, section.content, chat.config)
+          embeddings = copilot.resolve_embeddings(section.content, selected_model, chat.config)
         end
 
         for _, embedding in ipairs(embeddings) do
diff --git a/lua/CopilotChat/config/providers.lua b/lua/CopilotChat/config/providers.lua
index b51b307e..3114f285 100644
--- a/lua/CopilotChat/config/providers.lua
+++ b/lua/CopilotChat/config/providers.lua
@@ -43,6 +43,7 @@ local utils = require('CopilotChat.utils')
 ---@field get_agents nil|fun(headers:table):table
 ---@field get_models nil|fun(headers:table):table
 ---@field embed nil|string|fun(inputs:table, headers:table):table
+---@field search nil|string|fun(query:string, repository:string, headers:table):table
 ---@field prepare_input nil|fun(inputs:table, opts:CopilotChat.Provider.options):table
 ---@field prepare_output nil|fun(output:table, opts:CopilotChat.Provider.options):CopilotChat.Provider.output
 ---@field get_url nil|fun(opts:CopilotChat.Provider.options):string
@@ -100,11 +101,41 @@ local function get_github_token()
   error('Failed to find GitHub token')
 end
 
+local cached_gh_apps_token = nil
+
+--- Get the github apps token (gho_ token)
+---@return string
+local function get_gh_apps_token()
+  if cached_gh_apps_token then
+    return cached_gh_apps_token
+  end
+
+  async.util.scheduler()
+
+  local config_path = utils.config_path()
+  if not config_path then
+    error('Failed to find config path for GitHub token')
+  end
+
+  local file_path = config_path .. '/gh/hosts.yml'
+  if vim.fn.filereadable(file_path) == 1 then
+    local content = table.concat(vim.fn.readfile(file_path), '\n')
+    local token = content:match('oauth_token:%s*([%w_]+)')
+    if token then
+      cached_gh_apps_token = token
+      return token
+    end
+  end
+
+  error('Failed to find GitHub token')
+end
+
 ---@type table
 local M = {}
 
 M.copilot = {
   embed = 'copilot_embeddings',
+  search = 'copilot_search',
 
   get_headers = function()
     local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
@@ -271,6 +302,7 @@ M.copilot = {
 
 M.github_models = {
   embed = 'copilot_embeddings',
+  search = 'copilot_search',
 
   get_headers = function()
     return {
@@ -350,4 +382,80 @@ M.copilot_embeddings = {
   end,
 }
 
+M.copilot_search = {
+  get_headers = M.copilot.get_headers,
+
+  get_token = function()
+    return get_gh_apps_token(), nil
+  end,
+
+  search = function(query, repository, headers)
+    utils.curl_post(
+      'https://api.github.com/repos/' .. repository .. '/copilot_internal/embeddings_index',
+      {
+        headers = headers,
+      }
+    )
+
+    local response, err = utils.curl_get(
+      'https://api.github.com/repos/' .. repository .. '/copilot_internal/embeddings_index',
+      {
+        headers = headers,
+      }
+    )
+
+    if err then
+      error(err)
+    end
+
+    if response.status ~= 200 then
+      error('Failed to check search: ' .. tostring(response.status))
+    end
+
+    local body = vim.json.decode(response.body)
+
+    if
+      body.can_index ~= 'ok'
+      or not body.bm25_search_ok
+      or not body.lexical_search_ok
+      or not body.semantic_code_search_ok
+      or not body.semantic_doc_search_ok
+      or not body.semantic_indexing_enabled
+    then
+      error('Failed to search: ' .. vim.inspect(body))
+    end
+
+    local body = vim.json.encode({
+      query = query,
+      scopingQuery = '(repo:' .. repository .. ')',
+      similarity = 0.766,
+      limit = 100,
+    })
+
+    local response, err = utils.curl_post('https://api.individual.githubcopilot.com/search/code', {
+      headers = headers,
+      body = utils.temp_file(body),
+    })
+
+    if err then
+      error(err)
+    end
+
+    if response.status ~= 200 then
+      error('Failed to search: ' .. tostring(response.body))
+    end
+
+    local out = {}
+    for _, result in ipairs(vim.json.decode(response.body)) do
+      table.insert(out, {
+        filename = result.path,
+        filetype = result.languageName:lower(),
+        score = result.score,
+        content = result.contents,
+      })
+    end
+    return out
+  end,
+}
+
 return M
diff --git a/lua/CopilotChat/context.lua b/lua/CopilotChat/context.lua
index 22a7a282..c260c8a9 100644
--- a/lua/CopilotChat/context.lua
+++ b/lua/CopilotChat/context.lua
@@ -647,6 +647,20 @@ function M.quickfix()
   return out
 end
 
+--- Get the content of the current workspace
+---@param prompt string
+---@param model string
+function M.workspace(prompt, model)
+  local git_remote =
+    vim.trim(utils.system({ 'git', 'config', '--get', 'remote.origin.url' }).stdout)
+  local repo_path = git_remote:match('github.com[:/](.+).git$')
+  if not repo_path then
+    error('Could not determine GitHub repository from git remote: ' .. git_remote)
+  end
+
+  return client:search(prompt, repo_path, model)
+end
+
 --- Filter embeddings based on the query
 ---@param prompt string
 ---@param model string
diff --git a/lua/CopilotChat/init.lua b/lua/CopilotChat/init.lua
index bb0bcc7d..305ce085 100644
--- a/lua/CopilotChat/init.lua
+++ b/lua/CopilotChat/init.lua
@@ -247,9 +247,10 @@ end
 
 --- Resolve the embeddings from the prompt.
 ---@param prompt string
+---@param model string
 ---@param config CopilotChat.config.shared
 ---@return table, string
-function M.resolve_embeddings(prompt, config)
+function M.resolve_embeddings(prompt, model, config)
   local contexts = {}
   local function parse_context(prompt_context)
     local split = vim.split(prompt_context, ':')
@@ -289,7 +290,7 @@ function M.resolve_embeddings(prompt, config)
   for _, context_data in ipairs(contexts) do
     local context_value = M.config.contexts[context_data.name]
     for _, embedding in
-      ipairs(context_value.resolve(context_data.input, state.source or {}, prompt))
+      ipairs(context_value.resolve(context_data.input, state.source or {}, prompt, model))
     do
       if embedding then
         embeddings:set(embedding.filename, embedding)
@@ -672,7 +673,7 @@ function M.ask(prompt, config)
   local ok, err = pcall(async.run, function()
     local selected_agent, prompt = M.resolve_agent(prompt, config)
     local selected_model, prompt = M.resolve_model(prompt, config)
-    local embeddings, prompt = M.resolve_embeddings(prompt, config)
+    local embeddings, prompt = M.resolve_embeddings(prompt, selected_model, config)
 
     local has_output = false
     local query_ok, filtered_embeddings =