Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add prompt support #21

Merged
merged 1 commit into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@ const relativeToBinaryPath = "<me>"
type Client interface {
Run(context.Context, string, Options) (*Run, error)
Evaluate(context.Context, Options, ...fmt.Stringer) (*Run, error)
Parse(ctx context.Context, fileName string) ([]Node, error)
ParseTool(ctx context.Context, toolDef string) ([]Node, error)
Version(ctx context.Context) (string, error)
Fmt(ctx context.Context, nodes []Node) (string, error)
ListTools(ctx context.Context) (string, error)
ListModels(ctx context.Context) ([]string, error)
Confirm(ctx context.Context, resp AuthResponse) error
Parse(context.Context, string) ([]Node, error)
ParseTool(context.Context, string) ([]Node, error)
Version(context.Context) (string, error)
Fmt(context.Context, []Node) (string, error)
ListTools(context.Context) (string, error)
ListModels(context.Context) ([]string, error)
Confirm(context.Context, AuthResponse) error
PromptResponse(context.Context, PromptResponse) error
Close()
}

Expand Down Expand Up @@ -208,6 +209,11 @@ func (c *client) Confirm(ctx context.Context, resp AuthResponse) error {
return err
}

func (c *client) PromptResponse(ctx context.Context, resp PromptResponse) error {
_, err := c.runBasicCommand(ctx, "prompt-response/"+resp.ID, resp.Response)
return err
}

func (c *client) runBasicCommand(ctx context.Context, requestPath string, body any) (string, error) {
run := &Run{
url: c.gptscriptURL,
Expand Down
90 changes: 87 additions & 3 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func TestStreamEvaluate(t *testing.T) {

run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tool)
if err != nil {
t.Errorf("Error executing tool: %v", err)
t.Fatalf("Error executing tool: %v", err)
}

for e := range run.Events() {
Expand Down Expand Up @@ -297,7 +297,7 @@ func TestStreamRun(t *testing.T) {
var eventContent string
run, err := c.Run(context.Background(), wd+"/test/catcher.gpt", Options{IncludeEvents: true})
if err != nil {
t.Errorf("Error executing file: %v", err)
t.Fatalf("Error executing file: %v", err)
}

for e := range run.Events() {
Expand Down Expand Up @@ -618,7 +618,7 @@ func TestToolWithGlobalTools(t *testing.T) {

run, err := c.Run(context.Background(), wd+"/test/global-tools.gpt", Options{DisableCache: true, IncludeEvents: true})
if err != nil {
t.Errorf("Error executing tool: %v", err)
t.Fatalf("Error executing tool: %v", err)
}

for e := range run.Events() {
Expand Down Expand Up @@ -808,6 +808,90 @@ func TestConfirmDeny(t *testing.T) {
}
}

func TestPrompt(t *testing.T) {
var eventContent string
tools := []fmt.Stringer{
&ToolDef{
Instructions: "Use the sys.prompt user to ask the user for 'first name' which is not sensitive. After you get their first name, say hello.",
Tools: []string{"sys.prompt"},
},
}

run, err := c.Evaluate(context.Background(), Options{IncludeEvents: true}, tools...)
if err != nil {
t.Errorf("Error executing tool: %v", err)
}

// Wait for the prompt event
var promptFrame *PromptFrame
for e := range run.Events() {
if e.Call != nil {
for _, o := range e.Call.Output {
eventContent += o.Content
}
}
if e.Prompt != nil {
if e.Prompt.Type == EventTypePrompt {
promptFrame = e.Prompt
break
}
}
}

if promptFrame == nil {
t.Fatalf("No prompt call event")
}

if promptFrame.Sensitive {
t.Errorf("Unexpected sensitive prompt event: %v", promptFrame.Sensitive)
}

if !strings.Contains(promptFrame.Message, "first name") {
t.Errorf("unexpected confirm input: %s", promptFrame.Message)
}

if len(promptFrame.Fields) != 1 {
t.Fatalf("Unexpected number of fields: %d", len(promptFrame.Fields))
}

if promptFrame.Fields[0] != "first name" {
t.Errorf("Unexpected field: %s", promptFrame.Fields[0])
}

if err = c.PromptResponse(context.Background(), PromptResponse{
ID: promptFrame.ID,
Response: map[string]string{promptFrame.Fields[0]: "Clicky"},
}); err != nil {
t.Errorf("Error responding: %v", err)
}

// Read the remainder of the events
for e := range run.Events() {
if e.Call != nil {
for _, o := range e.Call.Output {
eventContent += o.Content
}
}
}

out, err := run.Text()
if err != nil {
t.Errorf("Error reading output: %v", err)
}

if !strings.Contains(eventContent, "Clicky") {
t.Errorf("Unexpected event output: %s", eventContent)
}

if !strings.Contains(out, "Hello") || !strings.Contains(out, "Clicky") {
t.Errorf("Unexpected output: %s", out)
}

if len(run.ErrorOutput()) != 0 {
t.Errorf("Should have no stderr output: %v", run.ErrorOutput())
}
}

func TestGetCommand(t *testing.T) {
currentEnvVar := os.Getenv("GPTSCRIPT_BIN")
t.Cleanup(func() {
Expand Down
58 changes: 33 additions & 25 deletions frame.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,33 @@
package gptscript

import (
"time"
import "time"

type ToolCategory string

type EventType string

const (
CredentialToolCategory ToolCategory = "credential"
ContextToolCategory ToolCategory = "context"
NoCategory ToolCategory = ""

EventTypeRunStart EventType = "runStart"
EventTypeCallStart EventType = "callStart"
EventTypeCallContinue EventType = "callContinue"
EventTypeCallSubCalls EventType = "callSubCalls"
EventTypeCallProgress EventType = "callProgress"
EventTypeChat EventType = "callChat"
EventTypeCallConfirm EventType = "callConfirm"
EventTypeCallFinish EventType = "callFinish"
EventTypeRunFinish EventType = "runFinish"

EventTypePrompt EventType = "prompt"
)

type Frame struct {
Run *RunFrame `json:"run,omitempty"`
Call *CallFrame `json:"call,omitempty"`
Run *RunFrame `json:"run,omitempty"`
Call *CallFrame `json:"call,omitempty"`
Prompt *PromptFrame `json:"prompt,omitempty"`
}

type RunFrame struct {
Expand Down Expand Up @@ -74,24 +95,11 @@ type InputContext struct {
Content string `json:"content,omitempty"`
}

type ToolCategory string

const (
CredentialToolCategory ToolCategory = "credential"
ContextToolCategory ToolCategory = "context"
NoCategory ToolCategory = ""
)

type EventType string

const (
EventTypeRunStart EventType = "runStart"
EventTypeCallStart EventType = "callStart"
EventTypeCallContinue EventType = "callContinue"
EventTypeCallSubCalls EventType = "callSubCalls"
EventTypeCallProgress EventType = "callProgress"
EventTypeChat EventType = "callChat"
EventTypeCallConfirm EventType = "callConfirm"
EventTypeCallFinish EventType = "callFinish"
EventTypeRunFinish EventType = "runFinish"
)
type PromptFrame struct {
ID string `json:"id,omitempty"`
Type EventType `json:"type,omitempty"`
Time time.Time `json:"time,omitempty"`
Message string `json:"message,omitempty"`
Fields []string `json:"fields,omitempty"`
Sensitive bool `json:"sensitive,omitempty"`
}
6 changes: 6 additions & 0 deletions prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package gptscript

type PromptResponse struct {
ID string `json:"id,omitempty"`
Response map[string]string `json:"response,omitempty"`
}