From 36b136039326acc4da56bed5799ecc2b07837d56 Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Fri, 8 Nov 2024 14:40:13 -0500 Subject: [PATCH] fix: add default model to the loader The tool loader will set the models on the tools if none is set. The way that that happens works for the CLI, but is not compatible with the SDK. This change makes the default model logic work with the SDK server. Signed-off-by: Donnie Adams --- pkg/builtin/builtin.go | 5 +++++ pkg/loader/loader.go | 44 +++++++++++++++++++++++--------------- pkg/loader/openapi_test.go | 18 ++++++++-------- pkg/sdkserver/run.go | 6 +++++- 4 files changed, 46 insertions(+), 27 deletions(-) diff --git a/pkg/builtin/builtin.go b/pkg/builtin/builtin.go index 90f36919..42ff373b 100644 --- a/pkg/builtin/builtin.go +++ b/pkg/builtin/builtin.go @@ -277,10 +277,15 @@ func ListTools() (result []types.Tool) { } func Builtin(name string) (types.Tool, bool) { + return BuiltinWithDefaultModel(name, "") +} + +func BuiltinWithDefaultModel(name, defaultModel string) (types.Tool, bool) { // Legacy syntax not used anymore name = strings.TrimSuffix(name, "?") t, ok := tools[name] t.Parameters.Name = name + t.Parameters.ModelName = defaultModel t.ID = name t.Instructions = "#!" + name return SetDefaults(t), ok diff --git a/pkg/loader/loader.go b/pkg/loader/loader.go index f2679c6f..5a907f5b 100644 --- a/pkg/loader/loader.go +++ b/pkg/loader/loader.go @@ -132,7 +132,7 @@ func loadLocal(base *source, name string) (*source, bool, error) { }, true, nil } -func loadProgram(data []byte, into *types.Program, targetToolName string) (types.Tool, error) { +func loadProgram(data []byte, into *types.Program, targetToolName, defaultModel string) (types.Tool, error) { var ext types.Program if err := json.Unmarshal(data[len(assemble.Header):], &ext); err != nil { @@ -141,7 +141,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types into.ToolSet = make(map[string]types.Tool, len(ext.ToolSet)) for k, v := range ext.ToolSet { - if builtinTool, ok := builtin.Builtin(k); ok { + if builtinTool, ok := builtin.BuiltinWithDefaultModel(k, defaultModel); ok { v = builtinTool } into.ToolSet[k] = v @@ -186,11 +186,11 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T { return openAPIDocument } -func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) { +func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName, defaultModel string) ([]types.Tool, error) { data := base.Content if bytes.HasPrefix(data, assemble.Header) { - tool, err := loadProgram(data, prg, targetToolName) + tool, err := loadProgram(data, prg, targetToolName, defaultModel) if err != nil { return nil, err } @@ -310,17 +310,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base localTools[strings.ToLower(tool.Parameters.Name)] = tool } - return linkAll(ctx, cache, prg, base, targetTools, localTools) + return linkAll(ctx, cache, prg, base, targetTools, localTools, defaultModel) } -func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) { +func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet, defaultModel string) (result []types.Tool, _ error) { localToolsMapping := make(map[string]string, len(tools)) for _, localTool := range localTools { localToolsMapping[strings.ToLower(localTool.Parameters.Name)] = localTool.ID } for _, tool := range tools { - tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping) + tool, err := link(ctx, cache, prg, base, tool, localTools, localToolsMapping, defaultModel) if err != nil { return nil, err } @@ -329,7 +329,7 @@ func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base return } -func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string) (types.Tool, error) { +func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet, localToolsMapping map[string]string, defaultModel string) (types.Tool, error) { if existing, ok := prg.ToolSet[tool.ID]; ok { return existing, nil } @@ -354,7 +354,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so linkedTool = existing } else { var err error - linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping) + linkedTool, err = link(ctx, cache, prg, base, localTool, localTools, localToolsMapping, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed linking %s at %s: %w", targetToolName, base, err) } @@ -364,7 +364,7 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so toolNames[targetToolName] = struct{}{} } else { toolName, subTool := types.SplitToolRef(targetToolName) - resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool) + resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool, defaultModel) if err != nil { return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err) } @@ -376,6 +376,10 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so tool.LocalTools = localToolsMapping + if tool.ModelName == "" { + tool.ModelName = defaultModel + } + tool = builtin.SetDefaults(tool) prg.ToolSet[tool.ID] = tool @@ -405,7 +409,7 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. Path: locationPath, Name: locationName, Location: opt.Location, - }, subToolName) + }, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -414,20 +418,26 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts .. } type Options struct { - Cache *cache.Client - Location string + Cache *cache.Client + Location string + DefaultModel string } func complete(opts ...Options) (result Options) { for _, opt := range opts { result.Cache = types.FirstSet(opt.Cache, result.Cache) result.Location = types.FirstSet(opt.Location, result.Location) + result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) } if result.Location == "" { result.Location = "inline" } + if result.DefaultModel == "" { + result.DefaultModel = builtin.GetDefaultModel() + } + return } @@ -451,7 +461,7 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty Name: name, ToolSet: types.ToolSet{}, } - tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName) + tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName, opt.DefaultModel) if err != nil { return types.Program{}, err } @@ -459,9 +469,9 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty return prg, nil } -func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) { +func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool, defaultModel string) ([]types.Tool, error) { if subTool == "" { - t, ok := builtin.Builtin(name) + t, ok := builtin.BuiltinWithDefaultModel(name, defaultModel) if ok { prg.ToolSet[t.ID] = t return []types.Tool{t}, nil @@ -473,7 +483,7 @@ func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base return nil, err } - result, err := readTool(ctx, cache, prg, s, subTool) + result, err := readTool(ctx, cache, prg, s, subTool, defaultModel) if err != nil { return nil, err } diff --git a/pkg/loader/openapi_test.go b/pkg/loader/openapi_test.go index 1a7eaa76..423246d1 100644 --- a/pkg/loader/openapi_test.go +++ b/pkg/loader/openapi_test.go @@ -26,7 +26,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err, "failed to read openapi v3") require.Equal(t, 3, numOpenAPITools(prgv3.ToolSet), "expected 3 openapi tools") @@ -35,7 +35,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.json") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2") require.Equal(t, 3, numOpenAPITools(prgv2json.ToolSet), "expected 3 openapi tools") @@ -44,7 +44,7 @@ func TestLoadOpenAPI(t *testing.T) { } datav2, err = os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "", "") require.NoError(t, err, "failed to read openapi v2 (yaml)") require.Equal(t, 3, numOpenAPITools(prgv2yaml.ToolSet), "expected 3 openapi tools") @@ -57,7 +57,7 @@ func TestOpenAPIv3(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -69,7 +69,7 @@ func TestOpenAPIv3NoOperationIDs(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -81,7 +81,7 @@ func TestOpenAPIv2(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) @@ -94,7 +94,7 @@ func TestOpenAPIv3Revamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -107,7 +107,7 @@ func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) { } datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "") + _, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi")) @@ -120,7 +120,7 @@ func TestOpenAPIv2Revamp(t *testing.T) { } datav2, err := os.ReadFile("testdata/openapi_v2.yaml") require.NoError(t, err) - _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "") + _, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "") require.NoError(t, err) autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi")) diff --git a/pkg/sdkserver/run.go b/pkg/sdkserver/run.go index b6b5a049..1c0f7c4b 100644 --- a/pkg/sdkserver/run.go +++ b/pkg/sdkserver/run.go @@ -32,7 +32,11 @@ func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, lo } defer g.Close(false) - prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache}) + defaultModel := opts.OpenAI.DefaultModel + if defaultModel == "" { + defaultModel = s.gptscriptOpts.OpenAI.DefaultModel + } + prg, err := programLoader(ctx, toolDef.String(), subTool, loader.Options{Cache: g.Cache, DefaultModel: defaultModel}) if err != nil { writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err)) return