Skip to content

fix(copilot)!: allow overriding headers, api_base #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,10 @@ Custom providers can implement these methods:
embed?: string|function,

-- Optional: Get extra request headers with optional expiration time
get_headers?(): table<string,string>, number?,
get_headers?(self: CopilotChat.Provider): table<string,string>, number?,

-- Optional: Get API endpoint URL
get_url?(opts: CopilotChat.Provider.options): string,
get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string,

-- Optional: Prepare request input
prepare_input?(inputs: table<CopilotChat.Provider.input>, opts: CopilotChat.Provider.options): table,
Expand All @@ -429,10 +429,10 @@ Custom providers can implement these methods:
prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output,

-- Optional: Get available models
get_models?(headers: table): table<CopilotChat.Provider.model>,
get_models?(self: CopilotChat.Provider, headers: table): table<CopilotChat.Provider.model>,

-- Optional: Get available agents
get_agents?(headers: table): table<CopilotChat.Provider.agent>,
get_agents?(self: CopilotChat.Provider, headers: table): table<CopilotChat.Provider.agent>,
}
```

Expand Down
8 changes: 4 additions & 4 deletions doc/CopilotChat.txt
Original file line number Diff line number Diff line change
Expand Up @@ -470,10 +470,10 @@ Custom providers can implement these methods:
embed?: string|function,

-- Optional: Get extra request headers with optional expiration time
get_headers?(): table<string,string>, number?,
get_headers?(self: CopilotChat.Provider): table<string,string>, number?,

-- Optional: Get API endpoint URL
get_url?(opts: CopilotChat.Provider.options): string,
get_url?(self: CopilotChat.Provider, opts: CopilotChat.Provider.options): string,

-- Optional: Prepare request input
prepare_input?(inputs: table<CopilotChat.Provider.input>, opts: CopilotChat.Provider.options): table,
Expand All @@ -482,10 +482,10 @@ Custom providers can implement these methods:
prepare_output?(output: table, opts: CopilotChat.Provider.options): CopilotChat.Provider.output,

-- Optional: Get available models
get_models?(headers: table): table<CopilotChat.Provider.model>,
get_models?(self: CopilotChat.Provider, headers: table): table<CopilotChat.Provider.model>,

-- Optional: Get available agents
get_agents?(headers: table): table<CopilotChat.Provider.agent>,
get_agents?(self: CopilotChat.Provider, headers: table): table<CopilotChat.Provider.agent>,
}
<

Expand Down
15 changes: 10 additions & 5 deletions lua/CopilotChat/client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ function Client:authenticate(provider_name)
local expires_at = self.provider_cache[provider_name].expires_at

if provider.get_headers and (not headers or (expires_at and expires_at <= math.floor(os.time()))) then
headers, expires_at = provider.get_headers()
headers, expires_at = provider:get_headers()
self.provider_cache[provider_name].headers = headers
self.provider_cache[provider_name].expires_at = expires_at
end
Expand All @@ -354,7 +354,7 @@ function Client:fetch_models()
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
goto continue
end
local ok, provider_models = pcall(provider.get_models, headers)
local ok, provider_models = pcall(provider.get_models, provider, headers)
if not ok then
log.warn('Failed to fetch models from ' .. provider_name .. ': ' .. provider_models)
goto continue
Expand Down Expand Up @@ -396,7 +396,7 @@ function Client:fetch_agents()
log.warn('Failed to authenticate with ' .. provider_name .. ': ' .. headers)
goto continue
end
local ok, provider_agents = pcall(provider.get_agents, headers)
local ok, provider_agents = pcall(provider.get_agents, provider, headers)
if not ok then
log.warn('Failed to fetch agents from ' .. provider_name .. ': ' .. provider_agents)
goto continue
Expand Down Expand Up @@ -671,7 +671,7 @@ function Client:ask(prompt, opts)
args.stream = stream_func
end

local response, err = utils.curl_post(provider.get_url(options), args)
local response, err = utils.curl_post(provider:get_url(options), args)

if not opts.headless then
if self.current_job ~= job_id then
Expand Down Expand Up @@ -815,7 +815,12 @@ function Client:embed(inputs, model)
local success = false
local attempts = 0
while not success and attempts < 5 do -- Limit total attempts to 5
local ok, data = pcall(embed, generate_embedding_request(batch, threshold), self:authenticate(provider_name))
local ok, data = pcall(
embed,
self.providers[models[model].provider],
generate_embedding_request(batch, threshold),
self:authenticate(provider_name)
)

if not ok then
log.debug('Failed to get embeddings: ', data)
Expand Down
51 changes: 32 additions & 19 deletions lua/CopilotChat/config/providers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,22 @@ end

---@class CopilotChat.Provider
---@field disabled nil|boolean
---@field get_headers nil|fun():table<string, string>,number?
---@field get_agents nil|fun(headers:table):table<CopilotChat.Provider.agent>
---@field get_models nil|fun(headers:table):table<CopilotChat.Provider.model>
---@field embed nil|string|fun(inputs:table<string>, headers:table):table<CopilotChat.Provider.embed>
---@field get_headers nil|fun(self: CopilotChat.Provider):table<string, string>,number?
---@field get_agents nil|fun(self: CopilotChat.Provider, headers:table):table<CopilotChat.Provider.agent>
---@field get_models nil|fun(self: CopilotChat.Provider, headers:table):table<CopilotChat.Provider.model>
---@field embed nil|string|fun(self: CopilotChat.Provider, inputs:table<string>, headers:table):table<CopilotChat.Provider.embed>
---@field prepare_input nil|fun(inputs:table<CopilotChat.Provider.input>, opts:CopilotChat.Provider.options):table
---@field prepare_output nil|fun(output:table, opts:CopilotChat.Provider.options):CopilotChat.Provider.output
---@field get_url nil|fun(opts:CopilotChat.Provider.options):string
---@field get_url nil|fun(self: CopilotChat.Provider, opts:CopilotChat.Provider.options):string

---@type table<string, CopilotChat.Provider>
local M = {}

M.copilot = {
embed = 'copilot_embeddings',
api_base = 'https://api.githubcopilot.com',

get_headers = function()
get_headers = function(self)
local response, err = utils.curl_get('https://api.github.com/copilot_internal/v2/token', {
json_response = true,
headers = {
Expand All @@ -129,6 +130,10 @@ M.copilot = {
if err then
error(err)
end
if response.body.endpoints and response.body.endpoints.api then
---@diagnostic disable-next-line: inject-field
self.api_base = response.body.endpoints.api
end

return {
['Authorization'] = 'Bearer ' .. response.body.token,
Expand All @@ -139,8 +144,9 @@ M.copilot = {
response.body.expires_at
end,

get_agents = function(headers)
local response, err = utils.curl_get('https://api.githubcopilot.com/agents', {
get_agents = function(self, headers)
---@diagnostic disable-next-line: undefined-field
local response, err = utils.curl_get(self.api_base .. '/agents', {
json_response = true,
headers = headers,
})
Expand All @@ -158,8 +164,9 @@ M.copilot = {
end, response.body.agents)
end,

get_models = function(headers)
local response, err = utils.curl_get('https://api.githubcopilot.com/models', {
get_models = function(self, headers)
---@diagnostic disable-next-line: undefined-field
local response, err = utils.curl_get(self.api_base .. '/models', {
json_response = true,
headers = headers,
})
Expand Down Expand Up @@ -197,7 +204,8 @@ M.copilot = {

for _, model in ipairs(models) do
if not model.policy then
utils.curl_post('https://api.githubcopilot.com/models/' .. model.id .. '/policy', {
---@diagnostic disable-next-line: undefined-field
utils.curl_post(self.api_base .. '/models/' .. model.id .. '/policy', {
headers = headers,
json_request = true,
body = { state = 'enabled' },
Expand Down Expand Up @@ -276,27 +284,29 @@ M.copilot = {
}
end,

get_url = function(opts)
get_url = function(self, opts)
if opts.agent then
return 'https://api.githubcopilot.com/agents/' .. opts.agent.id .. '?chat'
---@diagnostic disable-next-line: undefined-field
return self.api_base .. '/agents/' .. opts.agent.id .. '?chat'
end

return 'https://api.githubcopilot.com/chat/completions'
---@diagnostic disable-next-line: undefined-field
return self.api_base .. '/chat/completions'
end,
}

M.github_models = {
embed = 'copilot_embeddings',

get_headers = function()
get_headers = function(self)
return {
['Authorization'] = 'Bearer ' .. get_github_token(),
['x-ms-useragent'] = EDITOR_VERSION,
['x-ms-user-agent'] = EDITOR_VERSION,
}
end,

get_models = function(headers)
get_models = function(self, headers)
local response, err = utils.curl_post('https://api.catalog.azureml.ms/asset-gallery/v1.0/models', {
headers = headers,
json_request = true,
Expand Down Expand Up @@ -344,16 +354,19 @@ M.github_models = {
prepare_input = M.copilot.prepare_input,
prepare_output = M.copilot.prepare_output,

get_url = function()
get_url = function(self)
return 'https://models.inference.ai.azure.com/chat/completions'
end,
}

M.copilot_embeddings = {
get_headers = M.copilot.get_headers,
---@diagnostic disable-next-line: undefined-field
api_base = M.copilot.api_base,

embed = function(inputs, headers)
local response, err = utils.curl_post('https://api.githubcopilot.com/embeddings', {
embed = function(self, inputs, headers)
---@diagnostic disable-next-line: undefined-field
local response, err = utils.curl_post(self.api_base .. '/embeddings', {
headers = headers,
json_request = true,
json_response = true,
Expand Down
Loading