diff --git a/client.go b/client.go index e2e85f4..3114b30 100644 --- a/client.go +++ b/client.go @@ -27,13 +27,14 @@ const relativeToBinaryPath = "" type Client interface { Run(context.Context, string, Options) (*Run, error) Evaluate(context.Context, Options, ...fmt.Stringer) (*Run, error) - Parse(ctx context.Context, fileName string) ([]Node, error) - ParseTool(ctx context.Context, toolDef string) ([]Node, error) - Version(ctx context.Context) (string, error) - Fmt(ctx context.Context, nodes []Node) (string, error) - ListTools(ctx context.Context) (string, error) - ListModels(ctx context.Context) ([]string, error) - Confirm(ctx context.Context, resp AuthResponse) error + Parse(context.Context, string) ([]Node, error) + ParseTool(context.Context, string) ([]Node, error) + Version(context.Context) (string, error) + Fmt(context.Context, []Node) (string, error) + ListTools(context.Context) (string, error) + ListModels(context.Context) ([]string, error) + Confirm(context.Context, AuthResponse) error + PromptResponse(context.Context, PromptResponse) error Close() } @@ -208,6 +209,11 @@ func (c *client) Confirm(ctx context.Context, resp AuthResponse) error { return err } +func (c *client) PromptResponse(ctx context.Context, resp PromptResponse) error { + _, err := c.runBasicCommand(ctx, "prompt-response/"+resp.ID, resp.Response) + return err +} + func (c *client) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) { run := &Run{ url: c.gptscriptURL, diff --git a/client_test.go b/client_test.go index 39b840a..a00a34c 100644 --- a/client_test.go +++ b/client_test.go @@ -259,7 +259,7 @@ func TestStreamEvaluate(t *testing.T) { run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tool) if err != nil { - t.Errorf("Error executing tool: %v", err) + t.Fatalf("Error executing tool: %v", err) } for e := range run.Events() { @@ -297,7 +297,7 @@ func TestStreamRun(t *testing.T) { var eventContent string run, err := c.Run(context.Background(), wd+"/test/catcher.gpt", Options{IncludeEvents: true}) if err != nil { - t.Errorf("Error executing file: %v", err) + t.Fatalf("Error executing file: %v", err) } for e := range run.Events() { @@ -618,7 +618,7 @@ func TestToolWithGlobalTools(t *testing.T) { run, err := c.Run(context.Background(), wd+"/test/global-tools.gpt", Options{DisableCache: true, IncludeEvents: true}) if err != nil { - t.Errorf("Error executing tool: %v", err) + t.Fatalf("Error executing tool: %v", err) } for e := range run.Events() { @@ -808,6 +808,90 @@ func TestConfirmDeny(t *testing.T) { } } +func TestPrompt(t *testing.T) { + var eventContent string + tools := []fmt.Stringer{ + &ToolDef{ + Instructions: "Use the sys.prompt user to ask the user for 'first name' which is not sensitive. After you get their first name, say hello.", + Tools: []string{"sys.prompt"}, + }, + } + + run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tools...) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Wait for the prompt event + var promptFrame *PromptFrame + for e := range run.Events() { + if e.Call != nil { + for _, o := range e.Call.Output { + eventContent += o.Content + } + } + if e.Prompt != nil { + if e.Prompt.Type == EventTypePrompt { + promptFrame = e.Prompt + break + } + } + } + + if promptFrame == nil { + t.Fatalf("No prompt call event") + } + + if promptFrame.Sensitive { + t.Errorf("Unexpected sensitive prompt event: %v", promptFrame.Sensitive) + } + + if !strings.Contains(promptFrame.Message, "first name") { + t.Errorf("unexpected confirm input: %s", promptFrame.Message) + } + + if len(promptFrame.Fields) != 1 { + t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields)) + } + + if promptFrame.Fields[0] != "first name" { + t.Errorf("Unexpected field: %s", promptFrame.Fields[0]) + } + + if err = c.PromptResponse(context.Background(), PromptResponse{ + ID: promptFrame.ID, + Response: map[string]string{promptFrame.Fields[0]: "Clicky"}, + }); err != nil { + t.Errorf("Error responding: %v", err) + } + + // Read the remainder of the events + for e := range run.Events() { + if e.Call != nil { + for _, o := range e.Call.Output { + eventContent += o.Content + } + } + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %v", err) + } + + if !strings.Contains(eventContent, "Clicky") { + t.Errorf("Unexpected event output: %s", eventContent) + } + + if !strings.Contains(out, "Hello") || !strings.Contains(out, "Clicky") { + t.Errorf("Unexpected output: %s", out) + } + + if len(run.ErrorOutput()) != 0 { + t.Errorf("Should have no stderr output: %v", run.ErrorOutput()) + } +} + func TestGetCommand(t *testing.T) { currentEnvVar := os.Getenv("GPTSCRIPT_BIN") t.Cleanup(func() { diff --git a/frame.go b/frame.go index bd96079..847781c 100644 --- a/frame.go +++ b/frame.go @@ -1,12 +1,33 @@ package gptscript -import ( - "time" +import "time" + +type ToolCategory string + +type EventType string + +const ( + CredentialToolCategory ToolCategory = "credential" + ContextToolCategory ToolCategory = "context" + NoCategory ToolCategory = "" + + EventTypeRunStart EventType = "runStart" + EventTypeCallStart EventType = "callStart" + EventTypeCallContinue EventType = "callContinue" + EventTypeCallSubCalls EventType = "callSubCalls" + EventTypeCallProgress EventType = "callProgress" + EventTypeChat EventType = "callChat" + EventTypeCallConfirm EventType = "callConfirm" + EventTypeCallFinish EventType = "callFinish" + EventTypeRunFinish EventType = "runFinish" + + EventTypePrompt EventType = "prompt" ) type Frame struct { - Run *RunFrame `json:"run,omitempty"` - Call *CallFrame `json:"call,omitempty"` + Run *RunFrame `json:"run,omitempty"` + Call *CallFrame `json:"call,omitempty"` + Prompt *PromptFrame `json:"prompt,omitempty"` } type RunFrame struct { @@ -74,24 +95,11 @@ type InputContext struct { Content string `json:"content,omitempty"` } -type ToolCategory string - -const ( - CredentialToolCategory ToolCategory = "credential" - ContextToolCategory ToolCategory = "context" - NoCategory ToolCategory = "" -) - -type EventType string - -const ( - EventTypeRunStart EventType = "runStart" - EventTypeCallStart EventType = "callStart" - EventTypeCallContinue EventType = "callContinue" - EventTypeCallSubCalls EventType = "callSubCalls" - EventTypeCallProgress EventType = "callProgress" - EventTypeChat EventType = "callChat" - EventTypeCallConfirm EventType = "callConfirm" - EventTypeCallFinish EventType = "callFinish" - EventTypeRunFinish EventType = "runFinish" -) +type PromptFrame struct { + ID string `json:"id,omitempty"` + Type EventType `json:"type,omitempty"` + Time time.Time `json:"time,omitempty"` + Message string `json:"message,omitempty"` + Fields []string `json:"fields,omitempty"` + Sensitive bool `json:"sensitive,omitempty"` +} diff --git a/prompt.go b/prompt.go new file mode 100644 index 0000000..59bb762 --- /dev/null +++ b/prompt.go @@ -0,0 +1,6 @@ +package gptscript + +type PromptResponse struct { + ID string `json:"id,omitempty"` + Response map[string]string `json:"response,omitempty"` +}