Skip to content

Commit ff48fd2

Browse files
authored
enhance: Add proper aborting of runs (#94)
Aborting a run is different from "closing" it. Closing a run will result in an error. Aborting a run will cause it to stop at the next available event and not return any error. Instead, the run will have its text appended with "ABORTED BY USER" and all the chat state will be preserved. Signed-off-by: Donnie Adams <[email protected]>
1 parent eee4337 commit ff48fd2

File tree

3 files changed

+144
-2
lines changed

3 files changed

+144
-2
lines changed

gptscript.go

+5
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru
170170
}).NextChat(ctx, opts.Input)
171171
}
172172

173+
func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error {
174+
_, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil))
175+
return err
176+
}
177+
173178
type ParseOptions struct {
174179
DisableCache bool
175180
}

gptscript_test.go

+137-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"strconv"
1212
"strings"
1313
"testing"
14+
"time"
1415

1516
"github.com/getkin/kin-openapi/openapi3"
1617
"github.com/stretchr/testify/require"
@@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) {
134135
}
135136
}
136137

137-
func TestAbortRun(t *testing.T) {
138+
func TestCancelRun(t *testing.T) {
138139
tool := ToolDef{Instructions: "What is the capital of the united states?"}
139140

140141
run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
@@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) {
146147
<-run.Events()
147148

148149
if err := run.Close(); err != nil {
149-
t.Errorf("Error aborting run: %v", err)
150+
t.Errorf("Error canceling run: %v", err)
150151
}
151152

152153
if run.State() != Error {
@@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) {
158159
}
159160
}
160161

162+
func TestAbortChatCompletionRun(t *testing.T) {
163+
tool := ToolDef{Instructions: "What is the capital of the united states?"}
164+
165+
run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
166+
if err != nil {
167+
t.Errorf("Error executing tool: %v", err)
168+
}
169+
170+
// Abort the run after the first event from the LLM
171+
for e := range run.Events() {
172+
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
173+
break
174+
}
175+
}
176+
177+
if err := g.AbortRun(context.Background(), run); err != nil {
178+
t.Errorf("Error aborting run: %v", err)
179+
}
180+
181+
// Wait for run to stop
182+
for range run.Events() {
183+
continue
184+
}
185+
186+
if run.State() != Finished {
187+
t.Errorf("Unexpected run state: %s", run.State())
188+
}
189+
190+
if out, err := run.Text(); err != nil {
191+
t.Errorf("Error reading output: %v", err)
192+
} else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
193+
t.Errorf("Unexpected output: %s", out)
194+
}
195+
}
196+
197+
func TestAbortCommandRun(t *testing.T) {
198+
tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"}
199+
200+
run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
201+
if err != nil {
202+
t.Errorf("Error executing tool: %v", err)
203+
}
204+
205+
// Abort the run after the first event.
206+
for e := range run.Events() {
207+
if e.Call != nil && e.Call.Type == EventTypeChat {
208+
time.Sleep(2 * time.Second)
209+
break
210+
}
211+
}
212+
213+
if err := g.AbortRun(context.Background(), run); err != nil {
214+
t.Errorf("Error aborting run: %v", err)
215+
}
216+
217+
// Wait for run to stop
218+
for range run.Events() {
219+
continue
220+
}
221+
222+
if run.State() != Finished {
223+
t.Errorf("Unexpected run state: %s", run.State())
224+
}
225+
226+
if out, err := run.Text(); err != nil {
227+
t.Errorf("Error reading output: %v", err)
228+
} else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") {
229+
t.Errorf("Unexpected output: %s", out)
230+
}
231+
}
232+
161233
func TestSimpleEvaluate(t *testing.T) {
162234
tool := ToolDef{Instructions: "What is the capital of the united states?"}
163235

@@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) {
844916
}
845917
}
846918

919+
func TestAbortChat(t *testing.T) {
920+
tool := ToolDef{
921+
Chat: true,
922+
Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.",
923+
Tools: []string{"sys.chat.finish"},
924+
}
925+
926+
run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
927+
if err != nil {
928+
t.Fatalf("Error executing tool: %v", err)
929+
}
930+
inputs := []string{
931+
"Tell me a joke.",
932+
"What was my first message?",
933+
}
934+
935+
// Just wait for the chat to start up.
936+
for range run.Events() {
937+
continue
938+
}
939+
940+
for i, input := range inputs {
941+
run, err = run.NextChat(context.Background(), input)
942+
if err != nil {
943+
t.Fatalf("Error sending next input %q: %v", input, err)
944+
}
945+
946+
// Abort the run after the first event from the LLM
947+
for e := range run.Events() {
948+
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
949+
break
950+
}
951+
}
952+
953+
if i == 0 {
954+
if err := g.AbortRun(context.Background(), run); err != nil {
955+
t.Fatalf("Error aborting run: %v", err)
956+
}
957+
}
958+
959+
// Wait for the run to complete
960+
for range run.Events() {
961+
continue
962+
}
963+
964+
out, err := run.Text()
965+
if err != nil {
966+
t.Errorf("Error reading output: %s", run.ErrorOutput())
967+
t.Fatalf("Error reading output: %v", err)
968+
}
969+
970+
if i == 0 {
971+
if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
972+
t.Fatalf("Unexpected output: %s", out)
973+
}
974+
} else {
975+
if !strings.Contains(out, "Tell me a joke") {
976+
t.Errorf("Unexpected output: %s", out)
977+
}
978+
}
979+
}
980+
}
981+
847982
func TestFileChat(t *testing.T) {
848983
wd, err := os.Getwd()
849984
if err != nil {

run.go

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ type Run struct {
3737
basicCommand bool
3838

3939
program *Program
40+
id string
4041
callsLock sync.RWMutex
4142
calls CallFrames
4243
rawOutput map[string]any
@@ -400,6 +401,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
400401
if event.Run.Type == EventTypeRunStart {
401402
r.callsLock.Lock()
402403
r.program = &event.Run.Program
404+
r.id = event.Run.ID
403405
r.callsLock.Unlock()
404406
} else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" {
405407
r.state = Error

0 commit comments

Comments
 (0)