diff --git a/llms/bedrock/bedrockllm_test.go b/llms/bedrock/bedrockllm_test.go index 2df294d8b..19c71d1ac 100644 --- a/llms/bedrock/bedrockllm_test.go +++ b/llms/bedrock/bedrockllm_test.go @@ -40,13 +40,13 @@ func TestAmazonOutput(t *testing.T) { { Role: llms.ChatMessageTypeSystem, Parts: []llms.ContentPart{ - llms.TextPart("You know all about AI."), + llms.TextPart("You know all about AI.", false), }, }, { Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{ - llms.TextPart("Explain AI in 10 words or less."), + llms.TextPart("Explain AI in 10 words or less.", false), }, }, } diff --git a/llms/generatecontent.go b/llms/generatecontent.go index 8702143b0..e2d25153b 100644 --- a/llms/generatecontent.go +++ b/llms/generatecontent.go @@ -17,7 +17,10 @@ type MessageContent struct { } // TextPart creates TextContent from a given string. -func TextPart(s string) TextContent { +func TextPart(s string, isCacheable bool) TextContent { + if isCacheable { + return TextContent{Text: s, CacheControl: &CacheControl{Type: "ephemeral"}} + } return TextContent{Text: s} } @@ -52,7 +55,12 @@ type ContentPart interface { // TextContent is content with some text. type TextContent struct { - Text string + Text string + CacheControl *CacheControl `json:"cache_control,omitempty"` +} + +type CacheControl struct { + Type string `json:"type"` } func (tc TextContent) String() string { @@ -147,13 +155,13 @@ type ContentChoice struct { // TextParts is a helper function to create a MessageContent with a role and a // list of text parts. 
-func TextParts(role ChatMessageType, parts ...string) MessageContent { +func TextParts(role ChatMessageType, isCacheable bool, parts ...string) MessageContent { result := MessageContent{ Role: role, Parts: []ContentPart{}, } for _, part := range parts { - result.Parts = append(result.Parts, TextPart(part)) + result.Parts = append(result.Parts, TextPart(part, isCacheable)) } return result } diff --git a/llms/generatecontent_test.go b/llms/generatecontent_test.go index 71ce77a9f..2257733e9 100644 --- a/llms/generatecontent_test.go +++ b/llms/generatecontent_test.go @@ -28,7 +28,7 @@ func TestTextParts(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - if got := TextParts(tt.args.role, tt.args.parts...); !reflect.DeepEqual(got, tt.want) { + if got := TextParts(tt.args.role, false, tt.args.parts...); !reflect.DeepEqual(got, tt.want) { t.Errorf("TextParts() = %v, want %v", got, tt.want) } }) diff --git a/llms/googleai/shared_test/shared_test.go b/llms/googleai/shared_test/shared_test.go index de72e2cca..6e7ab977b 100644 --- a/llms/googleai/shared_test/shared_test.go +++ b/llms/googleai/shared_test/shared_test.go @@ -125,8 +125,8 @@ func testMultiContentText(t *testing.T, llm llms.Model) { t.Parallel() parts := []llms.ContentPart{ - llms.TextPart("I'm a pomeranian"), - llms.TextPart("What kind of mammal am I?"), + llms.TextPart("I'm a pomeranian", false), + llms.TextPart("What kind of mammal am I?", false), } content := []llms.MessageContent{ { @@ -151,6 +151,7 @@ func testMultiContentTextUsingTextParts(t *testing.T, llm llms.Model) { content := llms.TextParts( llms.ChatMessageTypeHuman, + false, "I'm a pomeranian", "What kind of mammal am I?", ) @@ -181,15 +182,15 @@ func testMultiContentTextChatSequence(t *testing.T, llm llms.Model) { content := []llms.MessageContent{ { Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{llms.TextPart("Name some countries")}, + Parts: []llms.ContentPart{llms.TextPart("Name some 
countries", false)}, }, { Role: llms.ChatMessageTypeAI, - Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")}, + Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)}, }, { Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")}, + Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)}, }, } @@ -208,11 +209,11 @@ func testMultiContentWithSystemMessage(t *testing.T, llm llms.Model) { content := []llms.MessageContent{ { Role: llms.ChatMessageTypeSystem, - Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish")}, + Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish", false)}, }, { Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits")}, + Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits", false)}, }, } @@ -232,7 +233,7 @@ func testMultiContentImageLink(t *testing.T, llm llms.Model) { llms.ImageURLPart( "https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true", ), - llms.TextPart("describe this image in detail"), + llms.TextPart("describe this image in detail", false), } content := []llms.MessageContent{ { @@ -264,7 +265,7 @@ func testMultiContentImageBinary(t *testing.T, llm llms.Model) { parts := []llms.ContentPart{ llms.BinaryPart("image/png", b), - llms.TextPart("what does this image show? please use detail"), + llms.TextPart("what does this image show? 
please use detail", false), } content := []llms.MessageContent{ { @@ -304,7 +305,7 @@ func testCandidateCountSetting(t *testing.T, llm llms.Model) { t.Helper() parts := []llms.ContentPart{ - llms.TextPart("Name five countries in Africa"), + llms.TextPart("Name five countries in Africa", false), } content := []llms.MessageContent{ { @@ -330,6 +331,7 @@ func testWithStreaming(t *testing.T, llm llms.Model) { content := llms.TextParts( llms.ChatMessageTypeHuman, + false, "I'm a pomeranian", "Tell me more about my taxonomy", ) @@ -376,7 +378,7 @@ func testTools(t *testing.T, llm llms.Model) { } content := []llms.MessageContent{ - llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"), + llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"), } resp, err := llm.GenerateContent( context.Background(), @@ -459,7 +461,7 @@ func testToolsWithInterfaceRequired(t *testing.T, llm llms.Model) { } content := []llms.MessageContent{ - llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"), + llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"), } resp, err := llm.GenerateContent( context.Background(), @@ -523,8 +525,8 @@ func testMaxTokensSetting(t *testing.T, llm llms.Model) { t.Parallel() parts := []llms.ContentPart{ - llms.TextPart("I'm a pomeranian"), - llms.TextPart("Describe my taxonomy, health and care"), + llms.TextPart("I'm a pomeranian", false), + llms.TextPart("Describe my taxonomy, health and care", false), } content := []llms.MessageContent{ { @@ -566,7 +568,7 @@ func testWithHTTPClient(t *testing.T, llm llms.Model) { resp, err := llm.GenerateContent( context.TODO(), - []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "testing")}, + []llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, false, "testing")}, ) require.NoError(t, err) require.EqualValues(t, "test-ok", resp.Choices[0].Content) diff --git 
a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go index 2bb572a0a..e418043f6 100644 --- a/llms/openai/internal/openaiclient/chat.go +++ b/llms/openai/internal/openaiclient/chat.go @@ -73,7 +73,9 @@ type ChatRequest struct { FunctionCallBehavior FunctionCallBehavior `json:"function_call,omitempty"` // Metadata allows you to specify additional information that will be passed to the model. - Metadata map[string]any `json:"metadata,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + // ExtraHeaders are extra HTTP headers to attach to the request; excluded from the JSON body (the OpenAI API rejects unrecognized request arguments). + ExtraHeaders map[string]string `json:"-"` } // ToolType is the type of a tool. diff --git a/llms/openai/multicontent_test.go b/llms/openai/multicontent_test.go index 04bdfe139..863a38499 100644 --- a/llms/openai/multicontent_test.go +++ b/llms/openai/multicontent_test.go @@ -29,8 +29,8 @@ func TestMultiContentText(t *testing.T) { llm := newTestClient(t) parts := []llms.ContentPart{ - llms.TextPart("I'm a pomeranian"), - llms.TextPart("What kind of mammal am I?"), + llms.TextPart("I'm a pomeranian", false), + llms.TextPart("What kind of mammal am I?", false), } content := []llms.MessageContent{ { @@ -54,15 +54,15 @@ func TestMultiContentTextChatSequence(t *testing.T) { content := []llms.MessageContent{ { Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{llms.TextPart("Name some countries")}, + Parts: []llms.ContentPart{llms.TextPart("Name some countries", false)}, }, { Role: llms.ChatMessageTypeAI, - Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")}, + Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)}, }, { Role: llms.ChatMessageTypeHuman, - Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")}, + Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)}, }, } @@ -81,7 +81,7 @@ func TestMultiContentImage(t *testing.T) { parts := []llms.ContentPart{ 
llms.ImageURLPart("https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true"), - llms.TextPart("describe this image in detail"), + llms.TextPart("describe this image in detail", false), } content := []llms.MessageContent{ { @@ -103,8 +103,8 @@ func TestWithStreaming(t *testing.T) { llm := newTestClient(t) parts := []llms.ContentPart{ - llms.TextPart("I'm a pomeranian"), - llms.TextPart("Tell me more about my taxonomy"), + llms.TextPart("I'm a pomeranian", false), + llms.TextPart("Tell me more about my taxonomy", false), } content := []llms.MessageContent{ { @@ -134,7 +134,7 @@ func TestFunctionCall(t *testing.T) { llm := newTestClient(t) parts := []llms.ContentPart{ - llms.TextPart("What is the weather like in Boston?"), + llms.TextPart("What is the weather like in Boston?", false), } content := []llms.MessageContent{ { diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 78f8334d2..9b3c52433 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -111,6 +111,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), Seed: opts.Seed, Metadata: opts.Metadata, + ExtraHeaders: opts.ExtraHeaders, } if opts.JSONMode { req.ResponseFormat = ResponseFormatJSON diff --git a/llms/options.go b/llms/options.go index b6b595290..c68055eac 100644 --- a/llms/options.go +++ b/llms/options.go @@ -66,6 +66,9 @@ type CallOptions struct { // Supported MIME types are: text/plain: (default) Text output. // application/json: JSON response in the response candidates. ResponseMIMEType string `json:"response_mime_type,omitempty"` + + // ExtraHeaders is a map of extra headers to include in the request. + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` } // Tool is a tool that can be used by the model. 
@@ -272,11 +275,19 @@ func WithMetadata(metadata map[string]interface{}) CallOption { o.Metadata = metadata } } // WithResponseMIMEType will add an option to set the ResponseMIMEType // Currently only supported by googleai llms. func WithResponseMIMEType(responseMIMEType string) CallOption { return func(o *CallOptions) { o.ResponseMIMEType = responseMIMEType } } + +// WithExtraHeaders will add an option to set extra headers to include in the request. +func WithExtraHeaders(extraHeaders map[string]string) CallOption { + return func(o *CallOptions) { + o.ExtraHeaders = extraHeaders + } +}