diff --git a/client_test.go b/client_test.go index 89b1343..08bfb4f 100644 --- a/client_test.go +++ b/client_test.go @@ -113,6 +113,10 @@ func TestSimpleEvaluate(t *testing.T) { if !strings.Contains(out, "Washington") { t.Errorf("Unexpected output: %s", out) } + + if run.Program() == nil { + t.Error("Run program not set") + } } func TestEvaluateWithContext(t *testing.T) { diff --git a/run.go b/run.go index cc820a3..98e5720 100644 --- a/run.go +++ b/run.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "log/slog" + "maps" "net/http" "os/exec" "strconv" @@ -26,10 +27,14 @@ type Run struct { wait func() basicCommand bool - rawOutput map[string]any - output, errput string - events chan Frame - lock sync.Mutex + program *Program + callsLock sync.RWMutex + calls map[string]CallFrame + parentCallFrameID string + rawOutput map[string]any + output, errput string + events chan Frame + lock sync.Mutex } // Text returns the text output of the gptscript. It blocks until the output is ready. @@ -59,6 +64,49 @@ func (r *Run) Err() error { return r.err } +// Program returns the gptscript program for the run. +func (r *Run) Program() *Program { + r.lock.Lock() + defer r.lock.Unlock() + return r.program +} + +// RespondingTool returns the name of the tool that produced the output. +func (r *Run) RespondingTool() Tool { + r.lock.Lock() + defer r.lock.Unlock() + + if r.program == nil { + return Tool{} + } + + s, ok := r.rawOutput["toolID"].(string) + if !ok { + return Tool{} + } + + return r.program.ToolSet[s] +} + +// Calls will return a flattened array of the calls for this run. +func (r *Run) Calls() map[string]CallFrame { + r.callsLock.RLock() + defer r.callsLock.RUnlock() + return maps.Clone(r.calls) +} + +// ParentCallFrame returns the CallFrame for the top-level or "parent" call. The boolean indicates whether there is a parent CallFrame. +func (r *Run) ParentCallFrame() (CallFrame, bool) { + r.callsLock.RLock() + defer r.callsLock.RUnlock() + + if r.parentCallFrameID == "" { + return CallFrame{}, false + } + + return r.calls[r.parentCallFrameID], true +} + // ErrorOutput returns the stderr output of the gptscript. // Should only be called after Bytes or Text has returned an error. func (r *Run) ErrorOutput() string { @@ -143,6 +191,10 @@ func (r *Run) NextChat(ctx context.Context, input string) (*Run, error) { } func (r *Run) request(ctx context.Context, payload any) (err error) { + if r.state.IsTerminal() { + return fmt.Errorf("run is in terminal state and cannot be run again: state %q", r.state) + } + var ( req *http.Request url = fmt.Sprintf("%s/%s", r.url, r.requestPath) @@ -205,6 +257,10 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { r.lock.Unlock() }() + r.callsLock.Lock() + r.calls = make(map[string]CallFrame) + r.callsLock.Unlock() + for n := 0; n != 0 || err == nil; n, err = resp.Body.Read(buf) { for _, line := range bytes.Split(bytes.TrimSpace(append(frag, buf[:n]...)), []byte("\n\n")) { line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: "))) @@ -287,6 +343,19 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { return } + if event.Call != nil { + r.callsLock.Lock() + r.calls[event.Call.ID] = *event.Call + if r.parentCallFrameID == "" && event.Call.ParentID == "" { + r.parentCallFrameID = event.Call.ID + } + r.callsLock.Unlock() + } else if event.Run != nil && event.Run.Type == EventTypeRunStart { + r.callsLock.Lock() + r.program = &event.Run.Program + r.callsLock.Unlock() + } + if r.opts.IncludeEvents { r.events <- event }