@@ -11,6 +11,7 @@ import (
11
11
"strconv"
12
12
"strings"
13
13
"testing"
14
+ "time"
14
15
15
16
"github.com/getkin/kin-openapi/openapi3"
16
17
"github.com/stretchr/testify/require"
@@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) {
134
135
}
135
136
}
136
137
137
- func TestAbortRun (t * testing.T ) {
138
+ func TestCancelRun (t * testing.T ) {
138
139
tool := ToolDef {Instructions : "What is the capital of the united states?" }
139
140
140
141
run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
@@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) {
146
147
<- run .Events ()
147
148
148
149
if err := run .Close (); err != nil {
149
- t .Errorf ("Error aborting run: %v" , err )
150
+ t .Errorf ("Error canceling run: %v" , err )
150
151
}
151
152
152
153
if run .State () != Error {
@@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) {
158
159
}
159
160
}
160
161
162
+ func TestAbortChatCompletionRun (t * testing.T ) {
163
+ tool := ToolDef {Instructions : "What is the capital of the united states?" }
164
+
165
+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
166
+ if err != nil {
167
+ t .Errorf ("Error executing tool: %v" , err )
168
+ }
169
+
170
+ // Abort the run after the first event from the LLM
171
+ for e := range run .Events () {
172
+ if e .Call != nil && e .Call .Type == EventTypeCallProgress && len (e .Call .Output ) > 0 && e .Call .Output [0 ].Content != "Waiting for model response..." {
173
+ break
174
+ }
175
+ }
176
+
177
+ if err := g .AbortRun (context .Background (), run ); err != nil {
178
+ t .Errorf ("Error aborting run: %v" , err )
179
+ }
180
+
181
+ // Wait for run to stop
182
+ for range run .Events () {
183
+ continue
184
+ }
185
+
186
+ if run .State () != Finished {
187
+ t .Errorf ("Unexpected run state: %s" , run .State ())
188
+ }
189
+
190
+ if out , err := run .Text (); err != nil {
191
+ t .Errorf ("Error reading output: %v" , err )
192
+ } else if strings .TrimSpace (out ) != "ABORTED BY USER" && ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
193
+ t .Errorf ("Unexpected output: %s" , out )
194
+ }
195
+ }
196
+
197
+ func TestAbortCommandRun (t * testing.T ) {
198
+ tool := ToolDef {Instructions : "#!/usr/bin/env bash\n echo Hello, world!\n sleep 5\n echo Hello, again!\n sleep 5" }
199
+
200
+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
201
+ if err != nil {
202
+ t .Errorf ("Error executing tool: %v" , err )
203
+ }
204
+
205
+ // Abort the run after the first event.
206
+ for e := range run .Events () {
207
+ if e .Call != nil && e .Call .Type == EventTypeChat {
208
+ time .Sleep (2 * time .Second )
209
+ break
210
+ }
211
+ }
212
+
213
+ if err := g .AbortRun (context .Background (), run ); err != nil {
214
+ t .Errorf ("Error aborting run: %v" , err )
215
+ }
216
+
217
+ // Wait for run to stop
218
+ for range run .Events () {
219
+ continue
220
+ }
221
+
222
+ if run .State () != Finished {
223
+ t .Errorf ("Unexpected run state: %s" , run .State ())
224
+ }
225
+
226
+ if out , err := run .Text (); err != nil {
227
+ t .Errorf ("Error reading output: %v" , err )
228
+ } else if ! strings .Contains (out , "Hello, world!" ) || strings .Contains (out , "Hello, again!" ) || ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
229
+ t .Errorf ("Unexpected output: %s" , out )
230
+ }
231
+ }
232
+
161
233
func TestSimpleEvaluate (t * testing.T ) {
162
234
tool := ToolDef {Instructions : "What is the capital of the united states?" }
163
235
@@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) {
844
916
}
845
917
}
846
918
919
+ func TestAbortChat (t * testing.T ) {
920
+ tool := ToolDef {
921
+ Chat : true ,
922
+ Instructions : "You are a chat bot. Don't finish the conversation until I say 'bye'." ,
923
+ Tools : []string {"sys.chat.finish" },
924
+ }
925
+
926
+ run , err := g .Evaluate (context .Background (), Options {DisableCache : true , IncludeEvents : true }, tool )
927
+ if err != nil {
928
+ t .Fatalf ("Error executing tool: %v" , err )
929
+ }
930
+ inputs := []string {
931
+ "Tell me a joke." ,
932
+ "What was my first message?" ,
933
+ }
934
+
935
+ // Just wait for the chat to start up.
936
+ for range run .Events () {
937
+ continue
938
+ }
939
+
940
+ for i , input := range inputs {
941
+ run , err = run .NextChat (context .Background (), input )
942
+ if err != nil {
943
+ t .Fatalf ("Error sending next input %q: %v" , input , err )
944
+ }
945
+
946
+ // Abort the run after the first event from the LLM
947
+ for e := range run .Events () {
948
+ if e .Call != nil && e .Call .Type == EventTypeCallProgress && len (e .Call .Output ) > 0 && e .Call .Output [0 ].Content != "Waiting for model response..." {
949
+ break
950
+ }
951
+ }
952
+
953
+ if i == 0 {
954
+ if err := g .AbortRun (context .Background (), run ); err != nil {
955
+ t .Fatalf ("Error aborting run: %v" , err )
956
+ }
957
+ }
958
+
959
+ // Wait for the run to complete
960
+ for range run .Events () {
961
+ continue
962
+ }
963
+
964
+ out , err := run .Text ()
965
+ if err != nil {
966
+ t .Errorf ("Error reading output: %s" , run .ErrorOutput ())
967
+ t .Fatalf ("Error reading output: %v" , err )
968
+ }
969
+
970
+ if i == 0 {
971
+ if strings .TrimSpace (out ) != "ABORTED BY USER" && ! strings .HasSuffix (out , "\n ABORTED BY USER" ) {
972
+ t .Fatalf ("Unexpected output: %s" , out )
973
+ }
974
+ } else {
975
+ if ! strings .Contains (out , "Tell me a joke" ) {
976
+ t .Errorf ("Unexpected output: %s" , out )
977
+ }
978
+ }
979
+ }
980
+ }
981
+
847
982
func TestFileChat (t * testing.T ) {
848
983
wd , err := os .Getwd ()
849
984
if err != nil {
0 commit comments