Skip to content

Commit 3e8b23c

Browse files
authored
feat: add dataset functions (#70)
Signed-off-by: Grant Linville <[email protected]>
1 parent 9064f15 commit 3e8b23c

File tree

4 files changed

+268
-2
lines changed

4 files changed

+268
-2
lines changed

datasets.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package gptscript
2+
3+
// DatasetElementMeta is the metadata of a single dataset element: its name
// and its description. It carries no element contents.
type DatasetElementMeta struct {
	// Name is encoded under the JSON key "name".
	Name string `json:"name"`
	// Description is encoded under the JSON key "description".
	Description string `json:"description"`
}
7+
8+
type DatasetElement struct {
9+
DatasetElementMeta `json:",inline"`
10+
Contents string `json:"contents"`
11+
}
12+
13+
// DatasetMeta is the metadata of a dataset: its ID, name and description.
type DatasetMeta struct {
	// ID is encoded under the JSON key "id".
	ID string `json:"id"`
	// Name is encoded under the JSON key "name".
	Name string `json:"name"`
	// Description is encoded under the JSON key "description".
	Description string `json:"description"`
}
18+
19+
type Dataset struct {
20+
DatasetMeta `json:",inline"`
21+
BaseDir string `json:"baseDir,omitempty"`
22+
Elements map[string]DatasetElementMeta `json:"elements"`
23+
}
24+
25+
// datasetRequest is the request body the dataset methods hand to
// runBasicCommand.
type datasetRequest struct {
	// Input holds the call's arguments as a JSON-encoded string ("{}" when
	// the call takes no arguments).
	Input string `json:"input"`
	// Workspace is the workspace the request operates on.
	Workspace string `json:"workspace"`
	// DatasetToolRepo is filled from the client's global options.
	DatasetToolRepo string `json:"datasetToolRepo"`
}
30+
31+
// createDatasetArgs is the argument payload that CreateDataset marshals to
// JSON and sends as the request input.
type createDatasetArgs struct {
	// Name is encoded under the JSON key "datasetName".
	Name string `json:"datasetName"`
	// Description is encoded under the JSON key "datasetDescription".
	Description string `json:"datasetDescription"`
}
35+
36+
// addDatasetElementArgs is the argument payload that AddDatasetElement
// marshals to JSON and sends as the request input.
type addDatasetElementArgs struct {
	// DatasetID identifies the dataset that receives the element.
	DatasetID string `json:"datasetID"`
	// ElementName is encoded under the JSON key "elementName".
	ElementName string `json:"elementName"`
	// ElementDescription is encoded under the JSON key "elementDescription".
	ElementDescription string `json:"elementDescription"`
	// ElementContent is encoded under the JSON key "elementContent".
	ElementContent string `json:"elementContent"`
}
42+
43+
// listDatasetElementArgs is the argument payload that ListDatasetElements
// marshals to JSON and sends as the request input.
type listDatasetElementArgs struct {
	// DatasetID identifies the dataset whose elements are listed.
	DatasetID string `json:"datasetID"`
}
46+
47+
// getDatasetElementArgs is the argument payload that GetDatasetElement
// marshals to JSON and sends as the request input.
type getDatasetElementArgs struct {
	// DatasetID identifies the dataset that holds the element.
	DatasetID string `json:"datasetID"`
	// Element is the name of the element to fetch.
	Element string `json:"element"`
}

gptscript.go

+171
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"encoding/base64"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"log/slog"
@@ -388,6 +389,176 @@ func (g *GPTScript) DeleteCredential(ctx context.Context, credCtx, name string)
388389
return err
389390
}
390391

392+
// Dataset methods
393+
394+
func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
395+
if workspace == "" {
396+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
397+
}
398+
399+
out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
400+
Input: "{}",
401+
Workspace: workspace,
402+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
403+
})
404+
405+
if err != nil {
406+
return nil, err
407+
}
408+
409+
if strings.HasPrefix(out, "ERROR:") {
410+
return nil, errors.New(out)
411+
}
412+
413+
var datasets []DatasetMeta
414+
if err = json.Unmarshal([]byte(out), &datasets); err != nil {
415+
return nil, err
416+
}
417+
return datasets, nil
418+
}
419+
420+
func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
421+
if workspace == "" {
422+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
423+
}
424+
425+
args := createDatasetArgs{
426+
Name: name,
427+
Description: description,
428+
}
429+
argsJSON, err := json.Marshal(args)
430+
if err != nil {
431+
return Dataset{}, fmt.Errorf("failed to marshal dataset args: %w", err)
432+
}
433+
434+
out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
435+
Input: string(argsJSON),
436+
Workspace: workspace,
437+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
438+
})
439+
440+
if err != nil {
441+
return Dataset{}, err
442+
}
443+
444+
if strings.HasPrefix(out, "ERROR:") {
445+
return Dataset{}, errors.New(out)
446+
}
447+
448+
var dataset Dataset
449+
if err = json.Unmarshal([]byte(out), &dataset); err != nil {
450+
return Dataset{}, err
451+
}
452+
return dataset, nil
453+
}
454+
455+
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
456+
if workspace == "" {
457+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
458+
}
459+
460+
args := addDatasetElementArgs{
461+
DatasetID: datasetID,
462+
ElementName: elementName,
463+
ElementDescription: elementDescription,
464+
ElementContent: elementContent,
465+
}
466+
argsJSON, err := json.Marshal(args)
467+
if err != nil {
468+
return DatasetElementMeta{}, fmt.Errorf("failed to marshal element args: %w", err)
469+
}
470+
471+
out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
472+
Input: string(argsJSON),
473+
Workspace: workspace,
474+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
475+
})
476+
477+
if err != nil {
478+
return DatasetElementMeta{}, err
479+
}
480+
481+
if strings.HasPrefix(out, "ERROR:") {
482+
return DatasetElementMeta{}, errors.New(out)
483+
}
484+
485+
var element DatasetElementMeta
486+
if err = json.Unmarshal([]byte(out), &element); err != nil {
487+
return DatasetElementMeta{}, err
488+
}
489+
return element, nil
490+
}
491+
492+
func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
493+
if workspace == "" {
494+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
495+
}
496+
497+
args := listDatasetElementArgs{
498+
DatasetID: datasetID,
499+
}
500+
argsJSON, err := json.Marshal(args)
501+
if err != nil {
502+
return nil, fmt.Errorf("failed to marshal element args: %w", err)
503+
}
504+
505+
out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
506+
Input: string(argsJSON),
507+
Workspace: workspace,
508+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
509+
})
510+
511+
if err != nil {
512+
return nil, err
513+
}
514+
515+
if strings.HasPrefix(out, "ERROR:") {
516+
return nil, errors.New(out)
517+
}
518+
519+
var elements []DatasetElementMeta
520+
if err = json.Unmarshal([]byte(out), &elements); err != nil {
521+
return nil, err
522+
}
523+
return elements, nil
524+
}
525+
526+
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
527+
if workspace == "" {
528+
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
529+
}
530+
531+
args := getDatasetElementArgs{
532+
DatasetID: datasetID,
533+
Element: elementName,
534+
}
535+
argsJSON, err := json.Marshal(args)
536+
if err != nil {
537+
return DatasetElement{}, fmt.Errorf("failed to marshal element args: %w", err)
538+
}
539+
540+
out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
541+
Input: string(argsJSON),
542+
Workspace: workspace,
543+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
544+
})
545+
546+
if err != nil {
547+
return DatasetElement{}, err
548+
}
549+
550+
if strings.HasPrefix(out, "ERROR:") {
551+
return DatasetElement{}, errors.New(out)
552+
}
553+
554+
var element DatasetElement
555+
if err = json.Unmarshal([]byte(out), &element); err != nil {
556+
return DatasetElement{}, err
557+
}
558+
559+
return element, nil
560+
}
561+
391562
func (g *GPTScript) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
392563
run := &Run{
393564
url: g.globalOpts.URL,

gptscript_test.go

+45-2
Original file line numberDiff line numberDiff line change
@@ -670,7 +670,7 @@ func TestParseToolWithTextNode(t *testing.T) {
670670
t.Fatalf("No text node found")
671671
}
672672

673-
if tools[1].TextNode.Text != "hello\n" {
673+
if strings.TrimSpace(tools[1].TextNode.Text) != "hello" {
674674
t.Errorf("Unexpected text: %s", tools[1].TextNode.Text)
675675
}
676676
if tools[1].TextNode.Fmt != "markdown" {
@@ -1047,7 +1047,7 @@ func TestConfirmDeny(t *testing.T) {
10471047
return
10481048
}
10491049

1050-
if !strings.Contains(confirmCallEvent.Input, "\"ls\"") {
1050+
if !strings.Contains(confirmCallEvent.Input, "ls") {
10511051
t.Errorf("unexpected confirm input: %s", confirmCallEvent.Input)
10521052
}
10531053

@@ -1560,3 +1560,46 @@ func TestCredentials(t *testing.T) {
15601560
require.Error(t, err)
15611561
require.True(t, errors.As(err, &ErrNotFound{}))
15621562
}
1563+
1564+
func TestDatasets(t *testing.T) {
1565+
workspace, err := os.MkdirTemp("", "go-gptscript-test")
1566+
require.NoError(t, err)
1567+
defer func() {
1568+
_ = os.RemoveAll(workspace)
1569+
}()
1570+
1571+
// Create a dataset
1572+
dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
1573+
require.NoError(t, err)
1574+
require.Equal(t, "test-dataset", dataset.Name)
1575+
require.Equal(t, "This is a test dataset", dataset.Description)
1576+
require.Equal(t, 0, len(dataset.Elements))
1577+
1578+
// Add an element
1579+
elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
1580+
require.NoError(t, err)
1581+
require.Equal(t, "test-element", elementMeta.Name)
1582+
require.Equal(t, "This is a test element", elementMeta.Description)
1583+
1584+
// Get the element
1585+
element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
1586+
require.NoError(t, err)
1587+
require.Equal(t, "test-element", element.Name)
1588+
require.Equal(t, "This is a test element", element.Description)
1589+
require.Equal(t, "This is the content", element.Contents)
1590+
1591+
// List elements in the dataset
1592+
elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
1593+
require.NoError(t, err)
1594+
require.Equal(t, 1, len(elements))
1595+
require.Equal(t, "test-element", elements[0].Name)
1596+
require.Equal(t, "This is a test element", elements[0].Description)
1597+
1598+
// List datasets
1599+
datasets, err := g.ListDatasets(context.Background(), workspace)
1600+
require.NoError(t, err)
1601+
require.Equal(t, 1, len(datasets))
1602+
require.Equal(t, "test-dataset", datasets[0].Name)
1603+
require.Equal(t, "This is a test dataset", datasets[0].Description)
1604+
require.Equal(t, dataset.ID, datasets[0].ID)
1605+
}

opts.go

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ type GlobalOptions struct {
1111
DefaultModelProvider string `json:"DefaultModelProvider"`
1212
CacheDir string `json:"CacheDir"`
1313
Env []string `json:"env"`
14+
DatasetToolRepo string `json:"DatasetToolRepo"`
1415
}
1516

1617
func (g GlobalOptions) toEnv() []string {
@@ -41,6 +42,7 @@ func completeGlobalOptions(opts ...GlobalOptions) GlobalOptions {
4142
result.OpenAIBaseURL = firstSet(opt.OpenAIBaseURL, result.OpenAIBaseURL)
4243
result.DefaultModel = firstSet(opt.DefaultModel, result.DefaultModel)
4344
result.DefaultModelProvider = firstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
45+
result.DatasetToolRepo = firstSet(opt.DatasetToolRepo, result.DatasetToolRepo)
4446
result.Env = append(result.Env, opt.Env...)
4547
}
4648
return result

0 commit comments

Comments
 (0)