Skip to content

Commit acc2dc8

Browse files
committed
feat: add datasets to sdkserver
Signed-off-by: Grant Linville <[email protected]>
1 parent cc5e5ed commit acc2dc8

File tree

4 files changed

+344
-0
lines changed

4 files changed

+344
-0
lines changed

pkg/cli/gptscript.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ type GPTScript struct {
7575
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
7676
DefaultModelProvider string `usage:"Default LLM model provider to use, this will override OpenAI settings"`
7777
GithubEnterpriseHostname string `usage:"The host name for a Github Enterprise instance to enable for remote loading" local:"true"`
78+
DatasetToolRepo string `usage:"The repo to use for dataset tools" default:"github.com/gptscript-ai/datasets" local:"true"`
7879

7980
readData []byte
8081
}
@@ -146,6 +147,7 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
146147
Workspace: r.Workspace,
147148
DisablePromptServer: r.UI,
148149
DefaultModelProvider: r.DefaultModelProvider,
150+
DatasetToolRepo: r.DatasetToolRepo,
149151
}
150152

151153
if r.Confirm {

pkg/gptscript/gptscript.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232

3333
var log = mvl.Package()
3434

35+
const defaultDatasetToolRepo = "github.com/gptscript-ai/datasets"
36+
3537
type GPTScript struct {
3638
Registry *llm.Registry
3739
Runner *runner.Runner
@@ -51,6 +53,7 @@ type Options struct {
5153
CredentialContexts []string
5254
Quiet *bool
5355
Workspace string
56+
DatasetToolRepo string
5457
DisablePromptServer bool
5558
Env []string
5659
}
@@ -66,6 +69,7 @@ func Complete(opts ...Options) Options {
6669
result.CredentialContexts = opt.CredentialContexts
6770
result.Quiet = types.FirstSet(opt.Quiet, result.Quiet)
6871
result.Workspace = types.FirstSet(opt.Workspace, result.Workspace)
72+
result.DatasetToolRepo = types.FirstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
6973
result.Env = append(result.Env, opt.Env...)
7074
result.DisablePromptServer = types.FirstSet(opt.DisablePromptServer, result.DisablePromptServer)
7175
result.DefaultModelProvider = types.FirstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
@@ -80,6 +84,9 @@ func Complete(opts ...Options) Options {
8084
if len(result.CredentialContexts) == 0 {
8185
result.CredentialContexts = []string{credentials.DefaultCredentialContext}
8286
}
87+
if result.DatasetToolRepo == "" {
88+
result.DatasetToolRepo = defaultDatasetToolRepo
89+
}
8390

8491
return result
8592
}

pkg/sdkserver/datasets.go

Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
package sdkserver
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"net/http"
7+
8+
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
9+
"github.com/gptscript-ai/gptscript/pkg/gptscript"
10+
"github.com/gptscript-ai/gptscript/pkg/loader"
11+
)
12+
13+
type datasetRequest struct {
14+
Input string `json:"input"`
15+
Workspace string `json:"workspace"`
16+
DatasetToolRepo string `json:"datasetToolRepo"`
17+
}
18+
19+
func (r datasetRequest) validate(requireInput bool) error {
20+
if r.Workspace == "" {
21+
return fmt.Errorf("workspace is required")
22+
} else if requireInput && r.Input == "" {
23+
return fmt.Errorf("input is required")
24+
}
25+
return nil
26+
}
27+
28+
func (r datasetRequest) opts(o gptscript.Options) gptscript.Options {
29+
opts := gptscript.Options{
30+
Cache: o.Cache,
31+
Monitor: o.Monitor,
32+
Runner: o.Runner,
33+
DatasetToolRepo: o.DatasetToolRepo,
34+
Workspace: r.Workspace,
35+
}
36+
if r.DatasetToolRepo != "" {
37+
opts.DatasetToolRepo = r.DatasetToolRepo
38+
}
39+
return opts
40+
}
41+
42+
func (s *server) listDatasets(w http.ResponseWriter, r *http.Request) {
43+
logger := gcontext.GetLogger(r.Context())
44+
45+
var req datasetRequest
46+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
47+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
48+
return
49+
}
50+
51+
if err := req.validate(false); err != nil {
52+
writeError(logger, w, http.StatusBadRequest, err)
53+
return
54+
}
55+
56+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
57+
if err != nil {
58+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
59+
return
60+
}
61+
62+
prg, err := loader.Program(r.Context(), "List Datasets from "+s.gptscriptOpts.DatasetToolRepo, "", loader.Options{
63+
Cache: g.Cache,
64+
})
65+
66+
if err != nil {
67+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
68+
return
69+
}
70+
71+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
72+
if err != nil {
73+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
74+
return
75+
}
76+
77+
writeResponse(logger, w, result)
78+
}
79+
80+
type createDatasetArgs struct {
81+
Name string `json:"dataset_name"`
82+
Description string `json:"dataset_description"`
83+
}
84+
85+
func (a createDatasetArgs) validate() error {
86+
if a.Name == "" {
87+
return fmt.Errorf("dataset_name is required")
88+
}
89+
return nil
90+
}
91+
92+
func (s *server) createDataset(w http.ResponseWriter, r *http.Request) {
93+
logger := gcontext.GetLogger(r.Context())
94+
95+
var req datasetRequest
96+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
97+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
98+
return
99+
}
100+
101+
if err := req.validate(true); err != nil {
102+
writeError(logger, w, http.StatusBadRequest, err)
103+
return
104+
}
105+
106+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
107+
if err != nil {
108+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
109+
return
110+
}
111+
112+
var args createDatasetArgs
113+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
114+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
115+
return
116+
}
117+
118+
if err := args.validate(); err != nil {
119+
writeError(logger, w, http.StatusBadRequest, err)
120+
return
121+
}
122+
123+
prg, err := loader.Program(r.Context(), "Create Dataset from "+s.gptscriptOpts.DatasetToolRepo, "", loader.Options{
124+
Cache: g.Cache,
125+
})
126+
127+
if err != nil {
128+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
129+
return
130+
}
131+
132+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
133+
if err != nil {
134+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
135+
return
136+
}
137+
138+
writeResponse(logger, w, result)
139+
}
140+
141+
type addDatasetElementArgs struct {
142+
DatasetID string `json:"dataset_id"`
143+
ElementName string `json:"element_name"`
144+
ElementDescription string `json:"element_description"`
145+
ElementContent string `json:"element_content"`
146+
}
147+
148+
func (a addDatasetElementArgs) validate() error {
149+
if a.DatasetID == "" {
150+
return fmt.Errorf("dataset_id is required")
151+
}
152+
if a.ElementName == "" {
153+
return fmt.Errorf("element_name is required")
154+
}
155+
if a.ElementContent == "" {
156+
return fmt.Errorf("element_content is required")
157+
}
158+
return nil
159+
}
160+
161+
func (s *server) addDatasetElement(w http.ResponseWriter, r *http.Request) {
162+
logger := gcontext.GetLogger(r.Context())
163+
164+
var req datasetRequest
165+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
166+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
167+
return
168+
}
169+
170+
if err := req.validate(true); err != nil {
171+
writeError(logger, w, http.StatusBadRequest, err)
172+
return
173+
}
174+
175+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
176+
if err != nil {
177+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
178+
return
179+
}
180+
181+
var args addDatasetElementArgs
182+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
183+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
184+
return
185+
}
186+
187+
if err := args.validate(); err != nil {
188+
writeError(logger, w, http.StatusBadRequest, err)
189+
return
190+
}
191+
192+
prg, err := loader.Program(r.Context(), "Add Element from "+s.gptscriptOpts.DatasetToolRepo, "", loader.Options{
193+
Cache: g.Cache,
194+
})
195+
if err != nil {
196+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
197+
return
198+
}
199+
200+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
201+
if err != nil {
202+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
203+
return
204+
}
205+
206+
writeResponse(logger, w, result)
207+
}
208+
209+
type listDatasetElementsArgs struct {
210+
DatasetID string `json:"dataset_id"`
211+
}
212+
213+
func (a listDatasetElementsArgs) validate() error {
214+
if a.DatasetID == "" {
215+
return fmt.Errorf("dataset_id is required")
216+
}
217+
return nil
218+
}
219+
220+
func (s *server) listDatasetElements(w http.ResponseWriter, r *http.Request) {
221+
logger := gcontext.GetLogger(r.Context())
222+
223+
var req datasetRequest
224+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
225+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
226+
return
227+
}
228+
229+
if err := req.validate(true); err != nil {
230+
writeError(logger, w, http.StatusBadRequest, err)
231+
return
232+
}
233+
234+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
235+
if err != nil {
236+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
237+
return
238+
}
239+
240+
var args listDatasetElementsArgs
241+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
242+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
243+
return
244+
}
245+
246+
if err := args.validate(); err != nil {
247+
writeError(logger, w, http.StatusBadRequest, err)
248+
return
249+
}
250+
251+
prg, err := loader.Program(r.Context(), "List Elements from "+s.gptscriptOpts.DatasetToolRepo, "", loader.Options{
252+
Cache: g.Cache,
253+
})
254+
if err != nil {
255+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
256+
return
257+
}
258+
259+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
260+
if err != nil {
261+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
262+
return
263+
}
264+
265+
writeResponse(logger, w, result)
266+
}
267+
268+
type getDatasetElementArgs struct {
269+
DatasetID string `json:"dataset_id"`
270+
Element string `json:"element"`
271+
}
272+
273+
func (a getDatasetElementArgs) validate() error {
274+
if a.DatasetID == "" {
275+
return fmt.Errorf("dataset_id is required")
276+
}
277+
if a.Element == "" {
278+
return fmt.Errorf("element is required")
279+
}
280+
return nil
281+
}
282+
283+
func (s *server) getDatasetElement(w http.ResponseWriter, r *http.Request) {
284+
logger := gcontext.GetLogger(r.Context())
285+
286+
var req datasetRequest
287+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
288+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
289+
return
290+
}
291+
292+
if err := req.validate(true); err != nil {
293+
writeError(logger, w, http.StatusBadRequest, err)
294+
return
295+
}
296+
297+
g, err := gptscript.New(r.Context(), req.opts(s.gptscriptOpts))
298+
if err != nil {
299+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to initialize gptscript: %w", err))
300+
return
301+
}
302+
303+
var args getDatasetElementArgs
304+
if err := json.Unmarshal([]byte(req.Input), &args); err != nil {
305+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to unmarshal input: %w", err))
306+
return
307+
}
308+
309+
if err := args.validate(); err != nil {
310+
writeError(logger, w, http.StatusBadRequest, err)
311+
return
312+
}
313+
314+
prg, err := loader.Program(r.Context(), "Get Element from "+s.gptscriptOpts.DatasetToolRepo, "", loader.Options{
315+
Cache: g.Cache,
316+
})
317+
if err != nil {
318+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
319+
return
320+
}
321+
322+
result, err := g.Run(r.Context(), prg, s.gptscriptOpts.Env, req.Input)
323+
if err != nil {
324+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to run program: %w", err))
325+
return
326+
}
327+
328+
writeResponse(logger, w, result)
329+
}

pkg/sdkserver/routes.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ func (s *server) addRoutes(mux *http.ServeMux) {
6666
mux.HandleFunc("POST /credentials/create", s.createCredential)
6767
mux.HandleFunc("POST /credentials/reveal", s.revealCredential)
6868
mux.HandleFunc("POST /credentials/delete", s.deleteCredential)
69+
70+
mux.HandleFunc("POST /datasets", s.listDatasets)
71+
mux.HandleFunc("POST /datasets/create", s.createDataset)
72+
mux.HandleFunc("POST /datasets/list-elements", s.listDatasetElements)
73+
mux.HandleFunc("POST /datasets/get-element", s.getDatasetElement)
74+
mux.HandleFunc("POST /datasets/add-element", s.addDatasetElement)
6975
}
7076

7177
// health just provides an endpoint for checking whether the server is running and accessible.

0 commit comments

Comments
 (0)