Skip to content

Commit 74c092c

Browse files
committed
feat: add prompt support
Signed-off-by: Donnie Adams <[email protected]>
1 parent b9cbffb commit 74c092c

File tree

4 files changed

+139
-35
lines changed

4 files changed

+139
-35
lines changed

client.go

+13-7
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,14 @@ const relativeToBinaryPath = "<me>"
2727
type Client interface {
2828
Run(context.Context, string, Options) (*Run, error)
2929
Evaluate(context.Context, Options, ...fmt.Stringer) (*Run, error)
30-
Parse(ctx context.Context, fileName string) ([]Node, error)
31-
ParseTool(ctx context.Context, toolDef string) ([]Node, error)
32-
Version(ctx context.Context) (string, error)
33-
Fmt(ctx context.Context, nodes []Node) (string, error)
34-
ListTools(ctx context.Context) (string, error)
35-
ListModels(ctx context.Context) ([]string, error)
36-
Confirm(ctx context.Context, resp AuthResponse) error
30+
Parse(context.Context, string) ([]Node, error)
31+
ParseTool(context.Context, string) ([]Node, error)
32+
Version(context.Context) (string, error)
33+
Fmt(context.Context, []Node) (string, error)
34+
ListTools(context.Context) (string, error)
35+
ListModels(context.Context) ([]string, error)
36+
Confirm(context.Context, AuthResponse) error
37+
PromptResponse(context.Context, PromptResponse) error
3738
Close()
3839
}
3940

@@ -208,6 +209,11 @@ func (c *client) Confirm(ctx context.Context, resp AuthResponse) error {
208209
return err
209210
}
210211

212+
func (c *client) PromptResponse(ctx context.Context, resp PromptResponse) error {
213+
_, err := c.runBasicCommand(ctx, "prompt-response/"+resp.ID, resp.Response)
214+
return err
215+
}
216+
211217
func (c *client) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
212218
run := &Run{
213219
url: c.gptscriptURL,

client_test.go

+87-3
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func TestStreamEvaluate(t *testing.T) {
259259

260260
run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tool)
261261
if err != nil {
262-
t.Errorf("Error executing tool: %v", err)
262+
t.Fatalf("Error executing tool: %v", err)
263263
}
264264

265265
for e := range run.Events() {
@@ -297,7 +297,7 @@ func TestStreamRun(t *testing.T) {
297297
var eventContent string
298298
run, err := c.Run(context.Background(), wd+"/test/catcher.gpt", Options{IncludeEvents: true})
299299
if err != nil {
300-
t.Errorf("Error executing file: %v", err)
300+
t.Fatalf("Error executing file: %v", err)
301301
}
302302

303303
for e := range run.Events() {
@@ -618,7 +618,7 @@ func TestToolWithGlobalTools(t *testing.T) {
618618

619619
run, err := c.Run(context.Background(), wd+"/test/global-tools.gpt", Options{DisableCache: true, IncludeEvents: true})
620620
if err != nil {
621-
t.Errorf("Error executing tool: %v", err)
621+
t.Fatalf("Error executing tool: %v", err)
622622
}
623623

624624
for e := range run.Events() {
@@ -808,6 +808,90 @@ func TestConfirmDeny(t *testing.T) {
808808
}
809809
}
810810

811+
func TestPrompt(t *testing.T) {
812+
var eventContent string
813+
tools := []fmt.Stringer{
814+
&ToolDef{
815+
Instructions: "Use the sys.prompt user to ask the user for 'first name' which is not sensitive. After you get their first name, say hello.",
816+
Tools: []string{"sys.prompt"},
817+
},
818+
}
819+
820+
run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tools...)
821+
if err != nil {
822+
t.Errorf("Error executing tool: %v", err)
823+
}
824+
825+
// Wait for the prompt event
826+
var promptFrame *PromptFrame
827+
for e := range run.Events() {
828+
if e.Call != nil {
829+
for _, o := range e.Call.Output {
830+
eventContent += o.Content
831+
}
832+
}
833+
if e.Prompt != nil {
834+
if e.Prompt.Type == EventTypePrompt {
835+
promptFrame = e.Prompt
836+
break
837+
}
838+
}
839+
}
840+
841+
if promptFrame == nil {
842+
t.Fatalf("No prompt call event")
843+
}
844+
845+
if promptFrame.Sensitive {
846+
t.Errorf("Unexpected sensitive prompt event: %v", promptFrame.Sensitive)
847+
}
848+
849+
if !strings.Contains(promptFrame.Message, "first name") {
850+
t.Errorf("unexpected confirm input: %s", promptFrame.Message)
851+
}
852+
853+
if len(promptFrame.Fields) != 1 {
854+
t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields))
855+
}
856+
857+
if promptFrame.Fields[0] != "first name" {
858+
t.Errorf("Unexpected field: %s", promptFrame.Fields[0])
859+
}
860+
861+
if err = c.PromptResponse(context.Background(), PromptResponse{
862+
ID: promptFrame.ID,
863+
Response: map[string]string{promptFrame.Fields[0]: "Clicky"},
864+
}); err != nil {
865+
t.Errorf("Error responding: %v", err)
866+
}
867+
868+
// Read the remainder of the events
869+
for e := range run.Events() {
870+
if e.Call != nil {
871+
for _, o := range e.Call.Output {
872+
eventContent += o.Content
873+
}
874+
}
875+
}
876+
877+
out, err := run.Text()
878+
if err != nil {
879+
t.Errorf("Error reading output: %v", err)
880+
}
881+
882+
if !strings.Contains(eventContent, "Clicky") {
883+
t.Errorf("Unexpected event output: %s", eventContent)
884+
}
885+
886+
if !strings.Contains(out, "Hello") || !strings.Contains(out, "Clicky") {
887+
t.Errorf("Unexpected output: %s", out)
888+
}
889+
890+
if len(run.ErrorOutput()) != 0 {
891+
t.Errorf("Should have no stderr output: %v", run.ErrorOutput())
892+
}
893+
}
894+
811895
func TestGetCommand(t *testing.T) {
812896
currentEnvVar := os.Getenv("GPTSCRIPT_BIN")
813897
t.Cleanup(func() {

frame.go

+33-25
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
package gptscript
22

3-
import (
4-
"time"
3+
import "time"
4+
5+
type ToolCategory string
6+
7+
type EventType string
8+
9+
const (
10+
CredentialToolCategory ToolCategory = "credential"
11+
ContextToolCategory ToolCategory = "context"
12+
NoCategory ToolCategory = ""
13+
14+
EventTypeRunStart EventType = "runStart"
15+
EventTypeCallStart EventType = "callStart"
16+
EventTypeCallContinue EventType = "callContinue"
17+
EventTypeCallSubCalls EventType = "callSubCalls"
18+
EventTypeCallProgress EventType = "callProgress"
19+
EventTypeChat EventType = "callChat"
20+
EventTypeCallConfirm EventType = "callConfirm"
21+
EventTypeCallFinish EventType = "callFinish"
22+
EventTypeRunFinish EventType = "runFinish"
23+
24+
EventTypePrompt EventType = "prompt"
525
)
626

727
type Frame struct {
8-
Run *RunFrame `json:"run,omitempty"`
9-
Call *CallFrame `json:"call,omitempty"`
28+
Run *RunFrame `json:"run,omitempty"`
29+
Call *CallFrame `json:"call,omitempty"`
30+
Prompt *PromptFrame `json:"prompt,omitempty"`
1031
}
1132

1233
type RunFrame struct {
@@ -74,24 +95,11 @@ type InputContext struct {
7495
Content string `json:"content,omitempty"`
7596
}
7697

77-
type ToolCategory string
78-
79-
const (
80-
CredentialToolCategory ToolCategory = "credential"
81-
ContextToolCategory ToolCategory = "context"
82-
NoCategory ToolCategory = ""
83-
)
84-
85-
type EventType string
86-
87-
const (
88-
EventTypeRunStart EventType = "runStart"
89-
EventTypeCallStart EventType = "callStart"
90-
EventTypeCallContinue EventType = "callContinue"
91-
EventTypeCallSubCalls EventType = "callSubCalls"
92-
EventTypeCallProgress EventType = "callProgress"
93-
EventTypeChat EventType = "callChat"
94-
EventTypeCallConfirm EventType = "callConfirm"
95-
EventTypeCallFinish EventType = "callFinish"
96-
EventTypeRunFinish EventType = "runFinish"
97-
)
98+
type PromptFrame struct {
99+
ID string `json:"id,omitempty"`
100+
Type EventType `json:"type,omitempty"`
101+
Time time.Time `json:"time,omitempty"`
102+
Message string `json:"message,omitempty"`
103+
Fields []string `json:"fields,omitempty"`
104+
Sensitive bool `json:"sensitive,omitempty"`
105+
}

prompt.go

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package gptscript
2+
3+
type PromptResponse struct {
4+
ID string `json:"id,omitempty"`
5+
Response map[string]string `json:"response,omitempty"`
6+
}

0 commit comments

Comments
 (0)