Skip to content

Commit 39497c0

Browse files
authored
Merge pull request #23 from thedadams/collect-calls
Collect calls
2 parents 8452bda + e645b06 commit 39497c0

File tree

2 files changed

+77
-4
lines changed

2 files changed

+77
-4
lines changed

client_test.go

+4
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ func TestSimpleEvaluate(t *testing.T) {
113113
if !strings.Contains(out, "Washington") {
114114
t.Errorf("Unexpected output: %s", out)
115115
}
116+
117+
if run.Program() == nil {
118+
t.Error("Run program not set")
119+
}
116120
}
117121

118122
func TestEvaluateWithContext(t *testing.T) {

run.go

+73-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"log/slog"
11+
"maps"
1112
"net/http"
1213
"os/exec"
1314
"strconv"
@@ -26,10 +27,14 @@ type Run struct {
2627
wait func()
2728
basicCommand bool
2829

29-
rawOutput map[string]any
30-
output, errput string
31-
events chan Frame
32-
lock sync.Mutex
30+
program *Program
31+
callsLock sync.RWMutex
32+
calls map[string]CallFrame
33+
parentCallFrameID string
34+
rawOutput map[string]any
35+
output, errput string
36+
events chan Frame
37+
lock sync.Mutex
3338
}
3439

3540
// Text returns the text output of the gptscript. It blocks until the output is ready.
@@ -59,6 +64,49 @@ func (r *Run) Err() error {
5964
return r.err
6065
}
6166

67+
// Program returns the gptscript program for the run.
68+
func (r *Run) Program() *Program {
69+
r.lock.Lock()
70+
defer r.lock.Unlock()
71+
return r.program
72+
}
73+
74+
// RespondingTool returns the name of the tool that produced the output.
75+
func (r *Run) RespondingTool() Tool {
76+
r.lock.Lock()
77+
defer r.lock.Unlock()
78+
79+
if r.program == nil {
80+
return Tool{}
81+
}
82+
83+
s, ok := r.rawOutput["toolID"].(string)
84+
if !ok {
85+
return Tool{}
86+
}
87+
88+
return r.program.ToolSet[s]
89+
}
90+
91+
// Calls will return a flattened array of the calls for this run.
92+
func (r *Run) Calls() map[string]CallFrame {
93+
r.callsLock.RLock()
94+
defer r.callsLock.RUnlock()
95+
return maps.Clone(r.calls)
96+
}
97+
98+
// ParentCallFrame returns the CallFrame for the top-level or "parent" call. The boolean indicates whether there is a parent CallFrame.
99+
func (r *Run) ParentCallFrame() (CallFrame, bool) {
100+
r.callsLock.RLock()
101+
defer r.callsLock.RUnlock()
102+
103+
if r.parentCallFrameID == "" {
104+
return CallFrame{}, false
105+
}
106+
107+
return r.calls[r.parentCallFrameID], true
108+
}
109+
62110
// ErrorOutput returns the stderr output of the gptscript.
63111
// Should only be called after Bytes or Text has returned an error.
64112
func (r *Run) ErrorOutput() string {
@@ -143,6 +191,10 @@ func (r *Run) NextChat(ctx context.Context, input string) (*Run, error) {
143191
}
144192

145193
func (r *Run) request(ctx context.Context, payload any) (err error) {
194+
if r.state.IsTerminal() {
195+
return fmt.Errorf("run is in terminal state and cannot be run again: state %q", r.state)
196+
}
197+
146198
var (
147199
req *http.Request
148200
url = fmt.Sprintf("%s/%s", r.url, r.requestPath)
@@ -205,6 +257,10 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
205257
r.lock.Unlock()
206258
}()
207259

260+
r.callsLock.Lock()
261+
r.calls = make(map[string]CallFrame)
262+
r.callsLock.Unlock()
263+
208264
for n := 0; n != 0 || err == nil; n, err = resp.Body.Read(buf) {
209265
for _, line := range bytes.Split(bytes.TrimSpace(append(frag, buf[:n]...)), []byte("\n\n")) {
210266
line = bytes.TrimSpace(bytes.TrimPrefix(line, []byte("data: ")))
@@ -287,6 +343,19 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
287343
return
288344
}
289345

346+
if event.Call != nil {
347+
r.callsLock.Lock()
348+
r.calls[event.Call.ID] = *event.Call
349+
if r.parentCallFrameID == "" && event.Call.ParentID == "" {
350+
r.parentCallFrameID = event.Call.ID
351+
}
352+
r.callsLock.Unlock()
353+
} else if event.Run != nil && event.Run.Type == EventTypeRunStart {
354+
r.callsLock.Lock()
355+
r.program = &event.Run.Program
356+
r.callsLock.Unlock()
357+
}
358+
290359
if r.opts.IncludeEvents {
291360
r.events <- event
292361
}

0 commit comments

Comments
 (0)