From 54197ea5dd0276ed341553a620dad8680887263b Mon Sep 17 00:00:00 2001 From: Avinash Thakur Date: Thu, 27 Mar 2025 18:42:45 +0530 Subject: [PATCH 1/2] fix(copilot)!: allow overriding headers, api_base --- README.md | 8 ++--- doc/CopilotChat.txt | 8 ++--- lua/CopilotChat/client.lua | 15 +++++--- lua/CopilotChat/config/providers.lua | 54 ++++++++++++++++------------ 4 files changed, 49 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index aad53e77..4033ca69 100644 --- a/README.md +++ b/README.md @@ -417,10 +417,10 @@ Custom providers can implement these methods: embed?: string|function, -- Optional: Get extra request headers with optional expiration time - get_headers?(): table, number?, + get_headers?(self: CopilotChat.Provider): table, number?, -- Optional: Get API endpoint URL - get_url?(opts: CopilotChat.Provider.options): string, + get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string, -- Optional: Prepare request input prepare_input?(inputs: table, opts: CopilotChat.Provider.options): table, @@ -429,10 +429,10 @@ Custom providers can implement these methods: prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output, -- Optional: Get available models - get_models?(headers: table): table, + get_models?(self: CopilotChat.Provider, headers: table): table, -- Optional: Get available agents - get_agents?(headers: table): table, + get_agents?(self: CopilotChat.Provider, headers: table): table, } ``` diff --git a/doc/CopilotChat.txt b/doc/CopilotChat.txt index ca01269d..3288a10e 100644 --- a/doc/CopilotChat.txt +++ b/doc/CopilotChat.txt @@ -470,10 +470,10 @@ Custom providers can implement these methods: embed?: string|function, -- Optional: Get extra request headers with optional expiration time - get_headers?(): table, number?, + get_headers?(self: CopilotChat.Provider): table, number?, -- Optional: Get API endpoint URL - get_url?(opts: CopilotChat.Provider.options): string, + get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string, -- Optional: Prepare request input prepare_input?(inputs: table, opts: CopilotChat.Provider.options): table, @@ -482,10 +482,10 @@ Custom providers can implement these methods: prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output, -- Optional: Get available models - get_models?(headers: table): table, + get_models?(self: CopilotChat.Provider, headers: table): table, -- Optional: Get available agents - get_agents?(headers: table): table, + get_agents?(self: CopilotChat.Provider, headers: table): table, } < diff --git a/lua/CopilotChat/client.lua b/lua/CopilotChat/client.lua index 5f17080d..5e6e961c 100644 --- a/lua/CopilotChat/client.lua +++ b/lua/CopilotChat/client.lua @@ -327,7 +327,7 @@ function Client:authenticate(provider_name) local expires_at = self.provider_cache[provider_name].expires_at if provider.get_headers and (not headers or (expires_at and expires_at <= math.floor(os.time()))) then - headers, expires_at = provider.get_headers() + headers, expires_at = provider:get_headers() self.provider_cache[provider_name].headers = headers self.provider_cache[provider_name].expires_at = expires_at end @@ -354,7 +354,7 @@ function Client:fetch_models() log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) goto continue end - local ok, provider_models = pcall(provider.get_models, headers) + local ok, provider_models = pcall(provider.get_models, provider, headers) if not ok then log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models) goto continue @@ -396,7 +396,7 @@ function Client:fetch_agents() log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers) goto continue end - local ok, provider_agents = pcall(provider.get_agents, headers) + local ok, provider_agents = pcall(provider.get_agents, provider, headers) if not ok then log.warn('Failed to fetch agents from ' .. provider_name .. ': ' .. provider_agents) goto continue @@ -671,7 +671,7 @@ function Client:ask(prompt, opts) args.stream = stream_func end - local response, err = utils.curl_post(provider.get_url(options), args) + local response, err = utils.curl_post(provider:get_url(options), args) if not opts.headless then if self.current_job ~= job_id then @@ -817,7 +817,12 @@ function Client:embed(inputs, model) local success = false local attempts = 0 while not success and attempts < 5 do -- Limit total attempts to 5 - local ok, data = pcall(embed, generate_embedding_request(batch, threshold), self:authenticate(provider_name)) + local ok, data = pcall( + embed, + self.providers[models[model].provider], + generate_embedding_request(batch, threshold), + self:authenticate(provider_name) + ) if not ok then log.debug('Failed to get embeddings: ', data) diff --git a/lua/CopilotChat/config/providers.lua b/lua/CopilotChat/config/providers.lua index 1ca335e3..87130067 100644 --- a/lua/CopilotChat/config/providers.lua +++ b/lua/CopilotChat/config/providers.lua @@ -104,21 +104,24 @@ end ---@class CopilotChat.Provider ---@field disabled nil|boolean ----@field get_headers nil|fun():table,number? ----@field get_agents nil|fun(headers:table):table ----@field get_models nil|fun(headers:table):table ----@field embed nil|string|fun(inputs:table, headers:table):table +---@field api_base string +---@field default_headers table? +---@field get_headers nil|fun(self: CopilotChat.Provider):table,number? +---@field get_agents nil|fun(self: CopilotChat.Provider, headers:table):table +---@field get_models nil|fun(self: CopilotChat.Provider, headers:table):table +---@field embed nil|string|fun(self: CopilotChat.Provider, inputs:table, headers:table):table ---@field prepare_input nil|fun(inputs:table, opts:CopilotChat.Provider.options):table ---@field prepare_output nil|fun(output:table, opts:CopilotChat.Provider.options):CopilotChat.Provider.output ----@field get_url nil|fun(opts:CopilotChat.Provider.options):string +---@field get_url nil|fun(self: CopilotChat.Provider, opts:CopilotChat.Provider.options):string ---@type table local M = {} M.copilot = { embed = 'copilot_embeddings', + api_base = 'https://api.githubcopilot.com', - get_headers = function() + get_headers = function(self) local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', { json_response = true, headers = { @@ -129,18 +132,21 @@ M.copilot = { if err then error(err) end + if response.body.endpoints and response.body.endpoints.api then + self.api_base = response.body.endpoints.api + end - return { + return vim.tbl_extend('force', { ['Authorization'] = 'Bearer ' .. response.body.token, ['Editor-Version'] = EDITOR_VERSION, ['Editor-Plugin-Version'] = 'CopilotChat.nvim/*', ['Copilot-Integration-Id'] = 'vscode-chat', - }, + }, self.default_headers or {}), response.body.expires_at end, - get_agents = function(headers) - local response, err = utils.curl_get('https://api.githubcopilot.com/agents', { + get_agents = function(self, headers) + local response, err = utils.curl_get(self.api_base .. '/agents', { json_response = true, headers = headers, }) @@ -158,8 +164,8 @@ M.copilot = { end, response.body.agents) end, - get_models = function(headers) - local response, err = utils.curl_get('https://api.githubcopilot.com/models', { + get_models = function(self, headers) + local response, err = utils.curl_get(self.api_base .. '/models', { json_response = true, headers = headers, }) @@ -197,7 +203,7 @@ M.copilot = { for _, model in ipairs(models) do if not model.policy then - utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', { + utils.curl_post(self.api_base .. '/models/' .. model.id .. '/policy', { headers = headers, json_request = true, body = { state = 'enabled' }, @@ -276,27 +282,28 @@ M.copilot = { } end, - get_url = function(opts) + get_url = function(self, opts) if opts.agent then - return 'https://api.githubcopilot.com/agents/' .. opts.agent.id .. '?chat' + return self.api_base .. '/agents/' .. opts.agent.id .. '?chat' end - return 'https://api.githubcopilot.com/chat/completions' + return self.api_base .. '/chat/completions' end, } M.github_models = { embed = 'copilot_embeddings', + api_base = 'https://api.githubcopilot.com', - get_headers = function() - return { + get_headers = function(self) + return vim.tbl_extend('force', { ['Authorization'] = 'Bearer ' .. get_github_token(), ['x-ms-useragent'] = EDITOR_VERSION, ['x-ms-user-agent'] = EDITOR_VERSION, - } + }, self.default_headers or {}) end, - get_models = function(headers) + get_models = function(self, headers) local response, err = utils.curl_post('https://api.catalog.azureml.ms/asset-gallery/v1.0/models', { headers = headers, json_request = true, @@ -336,16 +343,17 @@ M.github_models = { prepare_input = M.copilot.prepare_input, prepare_output = M.copilot.prepare_output, - get_url = function() + get_url = function(self) return 'https://models.inference.ai.azure.com/chat/completions' end, } M.copilot_embeddings = { get_headers = M.copilot.get_headers, + api_base = M.copilot.api_base, - embed = function(inputs, headers) - local response, err = utils.curl_post('https://api.githubcopilot.com/embeddings', { + embed = function(self, inputs, headers) + local response, err = utils.curl_post(self.api_base .. '/embeddings', { headers = headers, json_request = true, json_response = true, From e9a33c5cbc5d8743a61c3f1c4335ce35d3ee8450 Mon Sep 17 00:00:00 2001 From: Avinash Thakur Date: Tue, 15 Apr 2025 22:41:27 +0530 Subject: [PATCH 2/2] chore: remove default_headers and api_base from public API --- lua/CopilotChat/config/providers.lua | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/lua/CopilotChat/config/providers.lua b/lua/CopilotChat/config/providers.lua index 87130067..35b19eb5 100644 --- a/lua/CopilotChat/config/providers.lua +++ b/lua/CopilotChat/config/providers.lua @@ -104,8 +104,6 @@ end ---@class CopilotChat.Provider ---@field disabled nil|boolean ----@field api_base string ----@field default_headers table? ---@field get_headers nil|fun(self: CopilotChat.Provider):table,number? ---@field get_agents nil|fun(self: CopilotChat.Provider, headers:table):table ---@field get_models nil|fun(self: CopilotChat.Provider, headers:table):table @@ -133,19 +131,21 @@ M.copilot = { error(err) end if response.body.endpoints and response.body.endpoints.api then + ---@diagnostic disable-next-line: inject-field self.api_base = response.body.endpoints.api end - return vim.tbl_extend('force', { + return { ['Authorization'] = 'Bearer ' .. response.body.token, ['Editor-Version'] = EDITOR_VERSION, ['Editor-Plugin-Version'] = 'CopilotChat.nvim/*', ['Copilot-Integration-Id'] = 'vscode-chat', - }, self.default_headers or {}), + }, response.body.expires_at end, get_agents = function(self, headers) + ---@diagnostic disable-next-line: undefined-field local response, err = utils.curl_get(self.api_base .. '/agents', { json_response = true, headers = headers, @@ -165,6 +165,7 @@ M.copilot = { end, get_models = function(self, headers) + ---@diagnostic disable-next-line: undefined-field local response, err = utils.curl_get(self.api_base .. '/models', { json_response = true, headers = headers, @@ -203,6 +204,7 @@ M.copilot = { for _, model in ipairs(models) do if not model.policy then + ---@diagnostic disable-next-line: undefined-field utils.curl_post(self.api_base .. '/models/' .. model.id .. '/policy', { headers = headers, json_request = true, @@ -284,23 +286,24 @@ M.copilot = { get_url = function(self, opts) if opts.agent then + ---@diagnostic disable-next-line: undefined-field return self.api_base .. '/agents/' .. opts.agent.id .. '?chat' end + ---@diagnostic disable-next-line: undefined-field return self.api_base .. '/chat/completions' end, } M.github_models = { embed = 'copilot_embeddings', - api_base = 'https://api.githubcopilot.com', get_headers = function(self) - return vim.tbl_extend('force', { + return { ['Authorization'] = 'Bearer ' .. get_github_token(), ['x-ms-useragent'] = EDITOR_VERSION, ['x-ms-user-agent'] = EDITOR_VERSION, - }, self.default_headers or {}) + } end, get_models = function(self, headers) @@ -350,9 +353,11 @@ M.github_models = { M.copilot_embeddings = { get_headers = M.copilot.get_headers, + ---@diagnostic disable-next-line: undefined-field api_base = M.copilot.api_base, embed = function(self, inputs, headers) + ---@diagnostic disable-next-line: undefined-field local response, err = utils.curl_post(self.api_base .. '/embeddings', { headers = headers, json_request = true,