diff --git a/pkg/openai/client.go b/pkg/openai/client.go index db911962..295961b1 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -281,10 +281,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C chatMessage.ToolCalls = append(chatMessage.ToolCalls, toToolCall(*content.ToolCall)) } if content.Text != "" { - chatMessage.MultiContent = append(chatMessage.MultiContent, openai.ChatMessagePart{ - Type: openai.ChatMessagePartTypeText, - Text: content.Text, - }) + chatMessage.MultiContent = append(chatMessage.MultiContent, textToMultiContent(content.Text)...) } } @@ -306,6 +303,35 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C return } +const imagePrefix = "data:image/png;base64," + +func textToMultiContent(text string) []openai.ChatMessagePart { + var chatParts []openai.ChatMessagePart + parts := strings.Split(text, "\n") + for i := len(parts) - 1; i >= 0; i-- { + if strings.HasPrefix(parts[i], imagePrefix) { + chatParts = append(chatParts, openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: parts[i], + }, + }) + parts = parts[:i] + } else { + break + } + } + if len(parts) > 0 { + chatParts = append(chatParts, openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeText, + Text: strings.Join(parts, "\n"), + }) + } + + slices.Reverse(chatParts) + return chatParts +} + func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { if err := c.ValidAuth(); err != nil { if err := c.RetrieveAPIKey(ctx, env); err != nil { diff --git a/pkg/openai/client_test.go b/pkg/openai/client_test.go index 30f1705b..78f3eac2 100644 --- a/pkg/openai/client_test.go +++ b/pkg/openai/client_test.go @@ -9,6 +9,44 @@ import ( "github.com/hexops/valast" ) +func TestTextToMultiContent(t *testing.T) { + autogold.Expect([]openai.ChatMessagePart{{ + Type: "text", + Text: "hi\ndata:image/png;base64,xxxxx\n", + }}).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx\n")) + + autogold.Expect([]openai.ChatMessagePart{ + { + Type: "text", + Text: "hi", + }, + { + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"}, + }, + }).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx")) + + autogold.Expect([]openai.ChatMessagePart{{ + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"}, + }}).Equal(t, textToMultiContent("data:image/png;base64,xxxxx")) + + autogold.Expect([]openai.ChatMessagePart{ + { + Type: "text", + Text: "\none\ntwo", + }, + { + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"}, + }, + { + Type: "image_url", + ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,yyyyy"}, + }, + }).Equal(t, textToMultiContent("\none\ntwo\ndata:image/png;base64,xxxxx\ndata:image/png;base64,yyyyy")) +} + func Test_appendMessage(t *testing.T) { autogold.Expect(types.CompletionMessage{Content: []types.ContentPart{ {ToolCall: &types.CompletionToolCall{ diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e2699cf6..0b950059 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -651,6 +651,17 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher { return newParallelDispatcher(ctx) } +func idForToolCall(id string, state *engine.Return) string { + if state == nil || state.State == nil { + return id + } + tc, ok := state.State.Pending[id] + if !ok || tc.Index == nil { + return id + } + return fmt.Sprintf("%03d", *tc.Index) +} + func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) { var resultLock sync.Mutex @@ -693,7 +704,9 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, // Sort the id so if sequential the results are predictable ids := maps.Keys(state.Continuation.Calls) - sort.Strings(ids) + sort.Slice(ids, func(i, j int) bool { + return idForToolCall(ids[i], state.Continuation) < idForToolCall(ids[j], state.Continuation) + }) for _, id := range ids { call := state.Continuation.Calls[id]