Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions pkg/loader/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,20 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}
opt := complete(opts...)

var locationPath, locationName string
if opt.Location != "" {
locationPath = path.Dir(opt.Location)
locationName = path.Base(opt.Location)
}

prg := types.Program{
ToolSet: types.ToolSet{},
}
tools, err := readTool(ctx, opt.Cache, &prg, &source{
Content: []byte(content),
Location: "inline",
Path: locationPath,
Name: locationName,
Location: opt.Location,
}, subToolName)
if err != nil {
return types.Program{}, err
Expand All @@ -388,12 +396,18 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
}

type Options struct {
Cache *cache.Client
Cache *cache.Client
Location string
}

func complete(opts ...Options) (result Options) {
for _, opt := range opts {
result.Cache = types.FirstSet(opt.Cache, result.Cache)
result.Location = types.FirstSet(opt.Location, result.Location)
}

if result.Location == "" {
result.Location = "inline"
}

return
Expand Down
38 changes: 25 additions & 13 deletions pkg/loader/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
req.Header.Set("Authorization", "Bearer "+bearerToken)
}

data, err := getWithDefaults(req)
data, defaulted, err := getWithDefaults(req)
if err != nil {
return nil, false, fmt.Errorf("error loading %s: %v", url, err)
}

if defaulted != "" {
pathString = url
name = defaulted
if repo != nil {
repo.Path = path.Join(repo.Path, repo.Name)
repo.Name = defaulted
}
}

log.Debugf("opened %s", url)

result := &source{
Expand All @@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
return result, true, nil
}

func getWithDefaults(req *http.Request) ([]byte, error) {
func getWithDefaults(req *http.Request) ([]byte, string, error) {
originalPath := req.URL.Path

// First, try to get the original path as is. It might be an OpenAPI definition.
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusOK {
if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 {
return toolBytes, nil
}
toolBytes, err := io.ReadAll(resp.Body)
return toolBytes, "", err
}

base := path.Base(originalPath)
if strings.Contains(base, ".") {
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

for i, def := range types.DefaultFiles {
base := path.Base(originalPath)
if !strings.Contains(base, ".") {
req.URL.Path = path.Join(originalPath, def)
}
req.URL.Path = path.Join(originalPath, def)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
return nil, "", err
}
defer resp.Body.Close()

Expand All @@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) {
}

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
}

return io.ReadAll(resp.Body)
data, err := io.ReadAll(resp.Body)
return data, def, err
}

panic("unreachable")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/sdkserver/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
logger.Debugf("executing tool: %+v", reqObject)
var (
def fmt.Stringer = &reqObject.ToolDefs
programLoader loaderFunc = loader.ProgramFromSource
programLoader = loaderWithLocation(loader.ProgramFromSource, reqObject.Location)
)
if reqObject.Content != "" {
def = &reqObject.content
Expand Down
8 changes: 8 additions & 0 deletions pkg/sdkserver/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ import (

type loaderFunc func(context.Context, string, string, ...loader.Options) (types.Program, error)

func loaderWithLocation(f loaderFunc, loc string) loaderFunc {
return func(ctx context.Context, s string, s2 string, options ...loader.Options) (types.Program, error) {
return f(ctx, s, s2, append(options, loader.Options{
Location: loc,
})...)
}
}

func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) {
g, err := gptscript.New(ctx, s.gptscriptOpts, opts)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions pkg/sdkserver/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ type toolOrFileRequest struct {
CredentialContext string `json:"credentialContext"`
CredentialOverrides []string `json:"credentialOverrides"`
Confirm bool `json:"confirm"`
Location string `json:"location,omitempty"`
}

type content struct {
Expand Down