Skip to content

Commit cd749df

Browse files
authored
chore: add Add Elements method and update for workspace provider (#78)
Signed-off-by: Grant Linville <[email protected]>
1 parent 3901872 commit cd749df

File tree

2 files changed

+93
-35
lines changed

2 files changed

+93
-35
lines changed

datasets.go

+57-23
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ type Dataset struct {
3030
}
3131

3232
type datasetRequest struct {
33-
Input string `json:"input"`
34-
Workspace string `json:"workspace"`
35-
DatasetToolRepo string `json:"datasetToolRepo"`
33+
Input string `json:"input"`
34+
WorkspaceID string `json:"workspaceID"`
35+
DatasetToolRepo string `json:"datasetToolRepo"`
36+
Env []string `json:"env"`
3637
}
3738

3839
type createDatasetArgs struct {
@@ -47,6 +48,11 @@ type addDatasetElementArgs struct {
4748
ElementContent string `json:"elementContent"`
4849
}
4950

51+
type addDatasetElementsArgs struct {
52+
DatasetID string `json:"datasetID"`
53+
Elements []DatasetElement `json:"elements"`
54+
}
55+
5056
type listDatasetElementArgs struct {
5157
DatasetID string `json:"datasetID"`
5258
}
@@ -56,15 +62,16 @@ type getDatasetElementArgs struct {
5662
Element string `json:"element"`
5763
}
5864

59-
func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]DatasetMeta, error) {
60-
if workspace == "" {
61-
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
65+
func (g *GPTScript) ListDatasets(ctx context.Context, workspaceID string) ([]DatasetMeta, error) {
66+
if workspaceID == "" {
67+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
6268
}
6369

6470
out, err := g.runBasicCommand(ctx, "datasets", datasetRequest{
6571
Input: "{}",
66-
Workspace: workspace,
72+
WorkspaceID: workspaceID,
6773
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
74+
Env: g.globalOpts.Env,
6875
})
6976
if err != nil {
7077
return nil, err
@@ -77,9 +84,9 @@ func (g *GPTScript) ListDatasets(ctx context.Context, workspace string) ([]Datas
7784
return datasets, nil
7885
}
7986

80-
func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, description string) (Dataset, error) {
81-
if workspace == "" {
82-
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
87+
func (g *GPTScript) CreateDataset(ctx context.Context, workspaceID, name, description string) (Dataset, error) {
88+
if workspaceID == "" {
89+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
8390
}
8491

8592
args := createDatasetArgs{
@@ -93,8 +100,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript
93100

94101
out, err := g.runBasicCommand(ctx, "datasets/create", datasetRequest{
95102
Input: string(argsJSON),
96-
Workspace: workspace,
103+
WorkspaceID: workspaceID,
97104
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
105+
Env: g.globalOpts.Env,
98106
})
99107
if err != nil {
100108
return Dataset{}, err
@@ -107,9 +115,9 @@ func (g *GPTScript) CreateDataset(ctx context.Context, workspace, name, descript
107115
return dataset, nil
108116
}
109117

110-
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
111-
if workspace == "" {
112-
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
118+
func (g *GPTScript) AddDatasetElement(ctx context.Context, workspaceID, datasetID, elementName, elementDescription, elementContent string) (DatasetElementMeta, error) {
119+
if workspaceID == "" {
120+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
113121
}
114122

115123
args := addDatasetElementArgs{
@@ -125,8 +133,9 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,
125133

126134
out, err := g.runBasicCommand(ctx, "datasets/add-element", datasetRequest{
127135
Input: string(argsJSON),
128-
Workspace: workspace,
136+
WorkspaceID: workspaceID,
129137
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
138+
Env: g.globalOpts.Env,
130139
})
131140
if err != nil {
132141
return DatasetElementMeta{}, err
@@ -139,9 +148,32 @@ func (g *GPTScript) AddDatasetElement(ctx context.Context, workspace, datasetID,
139148
return element, nil
140149
}
141150

142-
func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetID string) ([]DatasetElementMeta, error) {
143-
if workspace == "" {
144-
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
151+
func (g *GPTScript) AddDatasetElements(ctx context.Context, workspaceID, datasetID string, elements []DatasetElement) error {
152+
if workspaceID == "" {
153+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
154+
}
155+
156+
args := addDatasetElementsArgs{
157+
DatasetID: datasetID,
158+
Elements: elements,
159+
}
160+
argsJSON, err := json.Marshal(args)
161+
if err != nil {
162+
return fmt.Errorf("failed to marshal element args: %w", err)
163+
}
164+
165+
_, err = g.runBasicCommand(ctx, "datasets/add-elements", datasetRequest{
166+
Input: string(argsJSON),
167+
WorkspaceID: workspaceID,
168+
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
169+
Env: g.globalOpts.Env,
170+
})
171+
return err
172+
}
173+
174+
func (g *GPTScript) ListDatasetElements(ctx context.Context, workspaceID, datasetID string) ([]DatasetElementMeta, error) {
175+
if workspaceID == "" {
176+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
145177
}
146178

147179
args := listDatasetElementArgs{
@@ -154,8 +186,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI
154186

155187
out, err := g.runBasicCommand(ctx, "datasets/list-elements", datasetRequest{
156188
Input: string(argsJSON),
157-
Workspace: workspace,
189+
WorkspaceID: workspaceID,
158190
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
191+
Env: g.globalOpts.Env,
159192
})
160193
if err != nil {
161194
return nil, err
@@ -168,9 +201,9 @@ func (g *GPTScript) ListDatasetElements(ctx context.Context, workspace, datasetI
168201
return elements, nil
169202
}
170203

171-
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID, elementName string) (DatasetElement, error) {
172-
if workspace == "" {
173-
workspace = os.Getenv("GPTSCRIPT_WORKSPACE_DIR")
204+
func (g *GPTScript) GetDatasetElement(ctx context.Context, workspaceID, datasetID, elementName string) (DatasetElement, error) {
205+
if workspaceID == "" {
206+
workspaceID = os.Getenv("GPTSCRIPT_WORKSPACE_ID")
174207
}
175208

176209
args := getDatasetElementArgs{
@@ -184,8 +217,9 @@ func (g *GPTScript) GetDatasetElement(ctx context.Context, workspace, datasetID,
184217

185218
out, err := g.runBasicCommand(ctx, "datasets/get-element", datasetRequest{
186219
Input: string(argsJSON),
187-
Workspace: workspace,
220+
WorkspaceID: workspaceID,
188221
DatasetToolRepo: g.globalOpts.DatasetToolRepo,
222+
Env: g.globalOpts.Env,
189223
})
190224
if err != nil {
191225
return DatasetElement{}, err

datasets_test.go

+36-12
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,72 @@ package gptscript
22

33
import (
44
"context"
5-
"os"
65
"testing"
76

87
"github.com/stretchr/testify/require"
98
)
109

1110
func TestDatasets(t *testing.T) {
12-
workspace, err := os.MkdirTemp("", "go-gptscript-test")
11+
workspaceID, err := g.CreateWorkspace(context.Background(), "directory")
1312
require.NoError(t, err)
13+
1414
defer func() {
15-
_ = os.RemoveAll(workspace)
15+
_ = g.DeleteWorkspace(context.Background(), DeleteWorkspaceOptions{WorkspaceID: workspaceID})
1616
}()
1717

1818
// Create a dataset
19-
dataset, err := g.CreateDataset(context.Background(), workspace, "test-dataset", "This is a test dataset")
19+
dataset, err := g.CreateDataset(context.Background(), workspaceID, "test-dataset", "This is a test dataset")
2020
require.NoError(t, err)
2121
require.Equal(t, "test-dataset", dataset.Name)
2222
require.Equal(t, "This is a test dataset", dataset.Description)
2323
require.Equal(t, 0, len(dataset.Elements))
2424

2525
// Add an element
26-
elementMeta, err := g.AddDatasetElement(context.Background(), workspace, dataset.ID, "test-element", "This is a test element", "This is the content")
26+
elementMeta, err := g.AddDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element", "This is a test element", "This is the content")
2727
require.NoError(t, err)
2828
require.Equal(t, "test-element", elementMeta.Name)
2929
require.Equal(t, "This is a test element", elementMeta.Description)
3030

31-
// Get the element
32-
element, err := g.GetDatasetElement(context.Background(), workspace, dataset.ID, "test-element")
31+
// Add two more
32+
err = g.AddDatasetElements(context.Background(), workspaceID, dataset.ID, []DatasetElement{
33+
{
34+
DatasetElementMeta: DatasetElementMeta{
35+
Name: "test-element-2",
36+
Description: "This is a test element 2",
37+
},
38+
Contents: "This is the content 2",
39+
},
40+
{
41+
DatasetElementMeta: DatasetElementMeta{
42+
Name: "test-element-3",
43+
Description: "This is a test element 3",
44+
},
45+
Contents: "This is the content 3",
46+
},
47+
})
48+
require.NoError(t, err)
49+
50+
// Get the first element
51+
element, err := g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element")
3352
require.NoError(t, err)
3453
require.Equal(t, "test-element", element.Name)
3554
require.Equal(t, "This is a test element", element.Description)
3655
require.Equal(t, "This is the content", element.Contents)
3756

57+
// Get the third element
58+
element, err = g.GetDatasetElement(context.Background(), workspaceID, dataset.ID, "test-element-3")
59+
require.NoError(t, err)
60+
require.Equal(t, "test-element-3", element.Name)
61+
require.Equal(t, "This is a test element 3", element.Description)
62+
require.Equal(t, "This is the content 3", element.Contents)
63+
3864
// List elements in the dataset
39-
elements, err := g.ListDatasetElements(context.Background(), workspace, dataset.ID)
65+
elements, err := g.ListDatasetElements(context.Background(), workspaceID, dataset.ID)
4066
require.NoError(t, err)
41-
require.Equal(t, 1, len(elements))
42-
require.Equal(t, "test-element", elements[0].Name)
43-
require.Equal(t, "This is a test element", elements[0].Description)
67+
require.Equal(t, 3, len(elements))
4468

4569
// List datasets
46-
datasets, err := g.ListDatasets(context.Background(), workspace)
70+
datasets, err := g.ListDatasets(context.Background(), workspaceID)
4771
require.NoError(t, err)
4872
require.Equal(t, 1, len(datasets))
4973
require.Equal(t, "test-dataset", datasets[0].Name)

0 commit comments

Comments
 (0)