Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ When you use `@copilot`, the LLM can call functions like `glob`, `file`, `gitdif
| - | `gh` | Show help message |

> [!WARNING]
> Some plugins (e.g. `copilot.vim`) may also map common keys like `<Tab>` in insert mode.
> Some plugins (e.g. `copilot.vim`) may also map common keys like `<Tab>` in insert mode.
> To avoid conflicts, disable Copilot's default `<Tab>` mapping with:
>
> ```lua
Expand Down Expand Up @@ -404,6 +404,21 @@ Add custom AI providers:
- `copilot` - GitHub Copilot (default)
- `github_models` - GitHub Marketplace models (disabled by default)

## Github Enterprise

If your employer provides access to Copilot via a Github Enterprise instance ("GHEC") you can provide the respective URLs with the following config keys:

```lua
{
-- github instance main address w/o protocol prefix, default: "github.com" (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
github_instance_url = 'mycorp.ghe.com',
-- github instance api address w/o protocol prefix, default: "api.github.com" (without "https://"). E.g.: "api.mycorp.ghe.com"
github_instance_api_url = 'api.mycorp.ghe.com',
}
```

(These keys are used in the default Copilot "provider", this is an alternative to defining a full custom provider)

# API Reference

## Core
Expand Down
12 changes: 7 additions & 5 deletions lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -250,19 +250,21 @@ function Client:models()
ipairs(get_cached(self.provider_cache[provider_name], 'models', function()
notify.publish(notify.STATUS, 'Fetching models from ' .. provider_name)

local ok, headers = pcall(self.authenticate, self, provider_name)
local ok, headers_or_err = pcall(self.authenticate, self, provider_name)
if not ok then
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
log.error('Failed to authenticate with ' .. provider_name .. ': ' .. headers_or_err)
error(headers_or_err)
return {}
end

local ok, models = pcall(provider.get_models, headers)
local ok, models_or_err = pcall(provider.get_models, headers_or_err)
if not ok then
log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. models)
log.error('Failed to fetch models from ' .. provider_name .. ': ' .. models_or_err)
error(models_or_err)
return {}
end

return models or {}
return models_or_err or {}
end))
do
model.provider = provider_name
Expand Down
5 changes: 5 additions & 0 deletions lua/CopilotChat/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
---@field functions table<string, CopilotChat.config.functions.Function>?
---@field prompts table<string, CopilotChat.config.prompts.Prompt|string>?
---@field mappings CopilotChat.config.mappings?
---@field github_instance_url string?
---@field github_instance_api_url string?
return {

-- Shared config starts here (can be passed to functions at runtime and configured via setup function)
Expand Down Expand Up @@ -102,6 +104,9 @@ return {

chat_autocomplete = true, -- Enable chat autocompletion (when disabled, requires manual `mappings.complete` trigger)

github_instance_url = 'github.com', -- github instance main address w/o protocol prefix (without "https://"). E.g. a github-enterprise address might look like this: "mycorp.ghe.com"
github_instance_api_url = 'api.github.com', -- github instance api address w/o protocol prefix (without "https://"). E.g.: "api.mycorp.ghe.com"

log_path = vim.fn.stdpath('state') .. '/CopilotChat.log', -- Default path to log file
history_path = vim.fn.stdpath('data') .. '/copilotchat_history', -- Default path to stored history

Expand Down
55 changes: 44 additions & 11 deletions lua/CopilotChat/config/providers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ local constants = require('CopilotChat.constants')
local notify = require('CopilotChat.notify')
local utils = require('CopilotChat.utils')
local plenary_utils = require('plenary.async.util')
local log = require('plenary.log')

local EDITOR_VERSION = 'Neovim/' .. vim.version().major .. '.' .. vim.version().minor .. '.' .. vim.version().patch

---@class CopilotChat
---@field config CopilotChat.config.Config
---@field chat CopilotChat.ui.chat.Chat
local MC = setmetatable({}, {
__index = function(t, key)
if key == 'config' then
return require('CopilotChat.config')
end
return rawget(t, key)
end,
})

local token_cache = nil
local unsaved_token_cache = {}
local function load_tokens()
Expand Down Expand Up @@ -50,7 +63,7 @@ end
---@return string
local function github_device_flow(tag, client_id, scope)
local function request_device_code()
local res = utils.curl_post('https://github.com/login/device/code', {
local res = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/device/code', {
body = {
client_id = client_id,
scope = scope,
Expand All @@ -66,7 +79,7 @@ local function github_device_flow(tag, client_id, scope)
while true do
plenary_utils.sleep(interval * 1000)

local res = utils.curl_post('https://github.com/login/oauth/access_token', {
local res = utils.curl_post('https://' .. MC.config.github_instance_url .. '/login/oauth/access_token', {
body = {
client_id = client_id,
device_code = device_code,
Expand Down Expand Up @@ -146,7 +159,7 @@ local function get_github_copilot_token(tag)
local parsed_data = utils.json_decode(file_data)
if parsed_data then
for key, value in pairs(parsed_data) do
if string.find(key, 'github.com') and value and value.oauth_token then
if string.find(key, MC.config.github_instance_url) and value and value.oauth_token then
return set_token(tag, value.oauth_token, false)
end
end
Expand All @@ -173,7 +186,7 @@ local function get_github_models_token(tag)

-- loading token from gh cli if available
if vim.fn.executable('gh') == 0 then
local result = utils.system({ 'gh', 'auth', 'token', '-h', 'github.com' })
local result = utils.system({ 'gh', 'auth', 'token', '-h', MC.config.github_instance_url })
if result and result.code == 0 and result.stdout then
local gh_token = vim.trim(result.stdout)
if gh_token ~= '' and not gh_token:find('no oauth token') then
Expand Down Expand Up @@ -205,23 +218,42 @@ end
---@field prepare_input nil|fun(inputs:table<CopilotChat.client.Message>, opts:CopilotChat.config.providers.Options):table
---@field prepare_output nil|fun(output:table, opts:CopilotChat.config.providers.Options):CopilotChat.config.providers.Output
---@field get_url nil|fun(opts:CopilotChat.config.providers.Options):string
---@field endpoints_api string?

---@type table<string, CopilotChat.config.providers.Provider>
local M = {}

M.copilot = {
endpoints_api = '',

get_headers = function()
local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
local url = 'https://' .. MC.config.github_instance_api_url .. '/copilot_internal/v2/token'
log.debug('get headers - get ' .. url)
local response, err = utils.curl_get(url, {
json_response = true,
headers = {
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_api_url),
},
})

if err then
error(err)
end

if response.body and response.body.endpoints and response.body.endpoints.api then
log.info('get_headers ok, authenticated. Use api endpoint: ' .. response.body.endpoints.api)
M.endpoints_api = response.body.endpoints.api
else
log.error(
'get_headers authenticated, but missing key "endpoints.api" in server response. response: '
.. utils.to_string(response)
)
error(
'get_headers authenticated, but missing key "endpoints.api" in server response. Check log for details: '
.. MC.config.log_path
)
end

return {
['Authorization'] = 'Bearer ' .. response.body.token,
['Editor-Version'] = EDITOR_VERSION,
Expand All @@ -232,10 +264,10 @@ M.copilot = {
end,

get_info = function(headers)
local response, err = utils.curl_get('https://api.github.com/copilot_internal/user', {
local response, err = utils.curl_get('https://' .. MC.config.github_instance_url .. '/copilot_internal/user', {
json_response = true,
headers = {
['Authorization'] = 'Token ' .. get_github_copilot_token('github_copilot'),
['Authorization'] = 'Token ' .. get_github_copilot_token(MC.config.github_instance_url),
},
})

Expand Down Expand Up @@ -282,7 +314,8 @@ M.copilot = {
end,

get_models = function(headers)
local response, err = utils.curl_get('https://api.githubcopilot.com/models', {
log.info('getting models .. headers: ' .. utils.to_string(headers))
local response, err = utils.curl_get(M.endpoints_api .. '/models', {
json_response = true,
headers = headers,
})
Expand Down Expand Up @@ -322,7 +355,7 @@ M.copilot = {

for _, model in ipairs(models) do
if not model.policy then
utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', {
utils.curl_post(M.endpoints_api .. '/models/' .. model.id .. '/policy', {
headers = headers,
json_request = true,
body = { state = 'enabled' },
Expand Down Expand Up @@ -448,7 +481,7 @@ M.copilot = {
end,

get_url = function()
return 'https://api.githubcopilot.com/chat/completions'
return M.endpoints_api .. '/chat/completions'
end,
}

Expand Down
38 changes: 38 additions & 0 deletions lua/CopilotChat/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,44 @@ M.curl_post = async.wrap(function(url, opts, callback)
curl.post(url, args)
end, 3)

function M.to_string(tbl)
-- credit: http://lua-users.org/wiki/TableSerialization (universal tostring)
local function table_print(tt, indent, done)
done = done or {}
indent = indent or 0
if type(tt) == 'table' then
local sb = {}
for key, value in pairs(tt) do
table.insert(sb, string.rep(' ', indent)) -- indent it
if type(value) == 'table' and not done[value] then
done[value] = true
table.insert(sb, key .. ' = {\n')
table.insert(sb, table_print(value, indent + 2, done))
table.insert(sb, string.rep(' ', indent)) -- indent it
table.insert(sb, '}\n')
elseif 'number' == type(key) then
table.insert(sb, string.format('"%s"\n', tostring(value)))
else
table.insert(sb, string.format('%s = "%s"\n', tostring(key), tostring(value)))
end
end
return table.concat(sb)
else
return tt .. '\n'
end
end

if 'nil' == type(tbl) then
return tostring(nil)
elseif 'table' == type(tbl) then
return table_print(tbl)
elseif 'string' == type(tbl) then
return tbl
else
return tostring(tbl)
end
end

local function filter_files(files, max_count)
local filetype = require('plenary.filetype')

Expand Down