Skip to content

Commit bb3c1cd

Browse files
committed
enhance: update credentials framework for oauth support
Signed-off-by: Grant Linville <[email protected]>
1 parent 418a00a commit bb3c1cd

File tree

2 files changed

+58
-35
lines changed

2 files changed

+58
-35
lines changed

pkg/credentials/credential.go

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,58 @@ import (
44
"encoding/json"
55
"fmt"
66
"strings"
7+
"time"
78

89
"github.com/docker/cli/cli/config/types"
910
)
1011

11-
const ctxSeparator = "///"
12-
1312
type CredentialType string
1413

1514
const (
15+
ctxSeparator = "///"
1616
CredentialTypeTool CredentialType = "tool"
1717
CredentialTypeModelProvider CredentialType = "modelProvider"
18+
ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL"
1819
)
1920

2021
type Credential struct {
21-
Context string `json:"context"`
22-
ToolName string `json:"toolName"`
23-
Type CredentialType `json:"type"`
24-
Env map[string]string `json:"env"`
22+
Context string `json:"context"`
23+
ToolName string `json:"toolName"`
24+
Type CredentialType `json:"type"`
25+
Env map[string]string `json:"env"`
26+
ExpiresAt *time.Time `json:"expiresAt"`
27+
RefreshToken string `json:"refreshToken"`
28+
}
29+
30+
func (c Credential) IsExpired() bool {
31+
if c.ExpiresAt == nil {
32+
return false
33+
}
34+
return time.Now().After(*c.ExpiresAt)
2535
}
2636

2737
func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
28-
env, err := json.Marshal(c.Env)
38+
cred, err := json.Marshal(c)
2939
if err != nil {
3040
return types.AuthConfig{}, err
3141
}
3242

3343
return types.AuthConfig{
34-
Username: string(c.Type),
35-
Password: string(env),
44+
Username: string(c.Type), // Username is required, but not used
45+
Password: string(cred),
3646
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
3747
}, nil
3848
}
3949

4050
func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) {
41-
var env map[string]string
42-
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
43-
return Credential{}, err
51+
var cred Credential
52+
if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil {
53+
// Legacy: try unmarshalling into just an env map
54+
var env map[string]string
55+
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
56+
return Credential{}, err
57+
}
58+
cred.Env = env
4459
}
4560

4661
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
@@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
6277
}
6378

6479
return Credential{
65-
Context: ctx,
66-
ToolName: tool,
67-
Type: CredentialType(credType),
68-
Env: env,
80+
Context: ctx,
81+
ToolName: tool,
82+
Type: CredentialType(credType),
83+
Env: cred.Env,
84+
ExpiresAt: cred.ExpiresAt,
85+
RefreshToken: cred.RefreshToken,
6986
}, nil
7087
}
7188

pkg/runner/runner.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ var (
250250
EventTypeRunFinish EventType = "runFinish"
251251
)
252252

253-
func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
253+
func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
254254
if ref.Arg == "" {
255255
return "", nil
256256
}
@@ -355,7 +355,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
355355
continue
356356
}
357357

358-
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
358+
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
359359
if err != nil {
360360
return nil, nil, err
361361
}
@@ -867,7 +867,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
867867
}
868868

869869
var (
870-
cred *credentials.Credential
870+
c *credentials.Credential
871871
exists bool
872872
)
873873

@@ -879,25 +879,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
879879
// Only try to look up the cred if the tool is on GitHub or has an alias.
880880
// If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name.
881881
if isGitHubTool(toolName) && credentialAlias == "" {
882-
cred, exists, err = r.credStore.Get(toolName)
882+
c, exists, err = r.credStore.Get(toolName)
883883
if err != nil {
884884
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err)
885885
}
886886
} else if credentialAlias != "" {
887-
cred, exists, err = r.credStore.Get(credentialAlias)
887+
c, exists, err = r.credStore.Get(credentialAlias)
888888
if err != nil {
889889
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err)
890890
}
891891
}
892892

893+
if c == nil {
894+
c = &credentials.Credential{}
895+
}
896+
893897
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
894898
// and save it in the store.
895-
if !exists {
899+
if !exists || c.IsExpired() {
896900
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
897901
if !ok || len(credToolRefs) != 1 {
898902
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
899903
}
900904

905+
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
906+
if exists && c.IsExpired() {
907+
credJson, err := json.Marshal(c)
908+
if err != nil {
909+
return nil, fmt.Errorf("failed to marshal credential: %w", err)
910+
}
911+
env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJson)))
912+
}
913+
914+
// Get the input for the credential tool, if there is any.
901915
var input string
902916
if args != nil {
903917
inputBytes, err := json.Marshal(args)
@@ -916,21 +930,13 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
916930
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
917931
}
918932

919-
var envMap struct {
920-
Env map[string]string `json:"env"`
921-
}
922-
if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil {
933+
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
923934
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
924935
}
925-
926-
cred = &credentials.Credential{
927-
Type: credentials.CredentialTypeTool,
928-
Env: envMap.Env,
929-
ToolName: credName,
930-
}
936+
c.ToolName = credName
931937

932938
isEmpty := true
933-
for _, v := range cred.Env {
939+
for _, v := range c.Env {
934940
if v != "" {
935941
isEmpty = false
936942
break
@@ -941,15 +947,15 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
941947
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
942948
if isEmpty {
943949
log.Warnf("Not saving empty credential for tool %s", toolName)
944-
} else if err := r.credStore.Add(*cred); err != nil {
950+
} else if err := r.credStore.Add(*c); err != nil {
945951
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
946952
}
947953
} else {
948954
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
949955
}
950956
}
951957

952-
for k, v := range cred.Env {
958+
for k, v := range c.Env {
953959
env = append(env, fmt.Sprintf("%s=%s", k, v))
954960
}
955961
}

0 commit comments

Comments
 (0)