Skip to content

Commit 67739ec

Browse files
committed
fix listing stacked credentials
Signed-off-by: Grant Linville <[email protected]>
1 parent eb63961 commit 67739ec

File tree

2 files changed

+117
-13
lines changed

2 files changed

+117
-13
lines changed

pkg/credentials/dbstore.go

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/adrg/xdg"
1515
"github.com/glebarez/sqlite"
1616
"github.com/gptscript-ai/gptscript/pkg/config"
17+
"golang.org/x/exp/maps"
1718
"gorm.io/gorm"
1819
"gorm.io/gorm/logger"
1920
"k8s.io/apimachinery/pkg/runtime/schema"
@@ -302,28 +303,59 @@ func (d *DBStore) Remove(ctx context.Context, toolName string) error {
302303
}
303304

304305
func (d *DBStore) List(ctx context.Context) ([]Credential, error) {
305-
var (
306-
dbCreds []GptscriptCredential
307-
err error
308-
)
309-
if err = d.db.WithContext(ctx).Where("context = ?", first(d.credCtxs)).Find(&dbCreds).Error; err != nil {
306+
if first(d.credCtxs) == AllCredentialContexts {
307+
return d.listAll(ctx)
308+
}
309+
310+
credsByContext := make(map[string][]GptscriptCredential)
311+
for _, credCtx := range d.credCtxs {
312+
var creds []GptscriptCredential
313+
if err := d.db.WithContext(ctx).Where("context = ?", credCtx).Find(&creds).Error; err != nil {
314+
return nil, fmt.Errorf("failed to list credentials: %w", err)
315+
}
316+
credsByContext[credCtx] = creds
317+
}
318+
319+
// Go through the contexts in reverse order so that higher priority contexts override lower ones.
320+
credsByName := make(map[string]Credential)
321+
for i := len(d.credCtxs) - 1; i >= 0; i-- {
322+
for _, dbCred := range credsByContext[d.credCtxs[i]] {
323+
dbCred, err := d.decryptCred(ctx, dbCred)
324+
if err != nil {
325+
return nil, fmt.Errorf("failed to decrypt credential: %w", err)
326+
}
327+
328+
cred, err := dbCredToCred(dbCred)
329+
if err != nil {
330+
return nil, fmt.Errorf("failed to convert GptscriptCredential to Credential: %w", err)
331+
}
332+
333+
credsByName[cred.ToolName] = cred
334+
}
335+
}
336+
337+
return maps.Values(credsByName), nil
338+
}
339+
340+
func (d *DBStore) listAll(ctx context.Context) ([]Credential, error) {
341+
var allCreds []GptscriptCredential
342+
if err := d.db.WithContext(ctx).Find(&allCreds).Error; err != nil {
310343
return nil, fmt.Errorf("failed to list credentials: %w", err)
311344
}
312345

313-
var credentials []Credential
314-
for _, dbCred := range dbCreds {
315-
dbCred, err = d.decryptCred(ctx, dbCred)
346+
var creds []Credential
347+
for _, dbCred := range allCreds {
348+
dbCred, err := d.decryptCred(ctx, dbCred)
316349
if err != nil {
317350
return nil, fmt.Errorf("failed to decrypt credential: %w", err)
318351
}
319352

320-
credential, err := dbCredToCred(dbCred)
353+
cred, err := dbCredToCred(dbCred)
321354
if err != nil {
322355
return nil, fmt.Errorf("failed to convert GptscriptCredential to Credential: %w", err)
323356
}
324-
325-
credentials = append(credentials, credential)
357+
creds = append(creds, cred)
326358
}
327359

328-
return credentials, nil
360+
return creds, nil
329361
}

pkg/credentials/dbstore_test.go

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ func TestDBStore(t *testing.T) {
2525
RefreshToken: "myrefreshtoken",
2626
}
2727

28-
cfg, _ := config.ReadCLIConfig("")
28+
cfg, err := config.ReadCLIConfig("")
29+
require.NoError(t, err)
2930

3031
// Set up the store
3132
store, err := NewDBStore(context.Background(), cfg, []string{credCtx})
@@ -60,3 +61,74 @@ func TestDBStore(t *testing.T) {
6061
// Delete the credential
6162
require.NoError(t, store.Remove(context.Background(), credential.ToolName))
6263
}
64+
65+
func TestDBStoreStackedContexts(t *testing.T) {
66+
const (
67+
credCtx1 = "testing1"
68+
credCtx2 = "testing2"
69+
)
70+
71+
bytes := make([]byte, 16)
72+
_, err := rand.Read(bytes)
73+
require.NoError(t, err)
74+
75+
credential1 := Credential{
76+
Context: credCtx1,
77+
ToolName: fmt.Sprintf("%x", bytes),
78+
Type: CredentialTypeTool,
79+
Env: map[string]string{"ENV_VAR": "value"},
80+
}
81+
82+
credential2 := Credential{
83+
Context: credCtx2,
84+
ToolName: fmt.Sprintf("%x", bytes),
85+
Type: CredentialTypeTool,
86+
Env: map[string]string{"ENV_VAR": "value"},
87+
}
88+
89+
cfg, err := config.ReadCLIConfig("")
90+
require.NoError(t, err)
91+
92+
// Set up the stores
93+
store1, err := NewDBStore(context.Background(), cfg, []string{credCtx1})
94+
require.NoError(t, err)
95+
store2, err := NewDBStore(context.Background(), cfg, []string{credCtx2})
96+
require.NoError(t, err)
97+
98+
// Create both credentials
99+
require.NoError(t, store1.Add(context.Background(), credential1))
100+
require.NoError(t, store2.Add(context.Background(), credential2))
101+
102+
// Set up a store with both contexts
103+
storeBoth, err := NewDBStore(context.Background(), cfg, []string{credCtx1, credCtx2})
104+
require.NoError(t, err)
105+
106+
// Get the credential. We should get credential1.
107+
cred, found, err := storeBoth.Get(context.Background(), credential1.ToolName)
108+
require.NoError(t, err)
109+
require.True(t, found)
110+
require.Equal(t, credential1.ToolName, cred.ToolName)
111+
require.Equal(t, credential1.Context, cred.Context)
112+
require.Equal(t, credential1.Env, cred.Env)
113+
114+
// List credentials. We should only get credential1.
115+
list, err := storeBoth.List(context.Background())
116+
require.NoError(t, err)
117+
118+
found = false
119+
for _, c := range list {
120+
if c.ToolName == credential1.ToolName {
121+
require.Equal(t, credential1.Env, c.Env)
122+
require.Equal(t, credential1.Context, c.Context)
123+
found = true
124+
break
125+
} else {
126+
require.Fail(t, "unexpected credential found")
127+
}
128+
}
129+
require.True(t, found)
130+
131+
// Delete both credentials
132+
require.NoError(t, store1.Remove(context.Background(), credential1.ToolName))
133+
require.NoError(t, store2.Remove(context.Background(), credential2.ToolName))
134+
}

0 commit comments

Comments
 (0)