From 5b6e33a659e5afcfc6ec957dedd08f0fb1c67532 Mon Sep 17 00:00:00 2001
From: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>
Date: Fri, 23 Aug 2024 01:31:56 +0800
Subject: [PATCH 1/4] add cache control

---
 llms/generatecontent.go | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/llms/generatecontent.go b/llms/generatecontent.go
index 8702143b0..cf2745f8e 100644
--- a/llms/generatecontent.go
+++ b/llms/generatecontent.go
@@ -17,7 +17,10 @@ type MessageContent struct {
 }
 
 // TextPart creates TextContent from a given string.
-func TextPart(s string) TextContent {
+func TextPart(s string, isCacheable bool) TextContent {
+	if isCacheable {
+		return TextContent{Text: s, CacheControl: CacheControl{Type: "ephemeral"}}
+	}
 	return TextContent{Text: s}
 }
 
@@ -52,7 +55,12 @@ type ContentPart interface {
 
 // TextContent is content with some text.
 type TextContent struct {
-	Text string
+	Text         string
+	CacheControl CacheControl `json:"cache_control"`
+}
+
+type CacheControl struct {
+	Type string `json:"type"`
 }
 
 func (tc TextContent) String() string {

From 9d75a77b72ebfd802e121b4f68afbf5f17b2113c Mon Sep 17 00:00:00 2001
From: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>
Date: Fri, 23 Aug 2024 01:41:10 +0800
Subject: [PATCH 2/4] add isCacheable to text parts

---
 llms/generatecontent.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/generatecontent.go b/llms/generatecontent.go
index cf2745f8e..e2d25153b 100644
--- a/llms/generatecontent.go
+++ b/llms/generatecontent.go
@@ -155,13 +155,13 @@ type ContentChoice struct {
 
 // TextParts is a helper function to create a MessageContent with a role and a
 // list of text parts.
-func TextParts(role ChatMessageType, parts ...string) MessageContent {
+func TextParts(role ChatMessageType, isCacheable bool, parts ...string) MessageContent {
 	result := MessageContent{
 		Role:  role,
 		Parts: []ContentPart{},
 	}
 	for _, part := range parts {
-		result.Parts = append(result.Parts, TextPart(part))
+		result.Parts = append(result.Parts, TextPart(part, isCacheable))
 	}
 	return result
 }
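Taken together, patches 1 and 2 make caching an explicit per-part choice. Below is a minimal usage sketch against the patched llms package; the prompt strings, and the idea of caching only a long system prompt, are illustrative and not part of the patch. Marking a part cacheable simply attaches CacheControl{Type: "ephemeral"}, which the json tags serialize as "cache_control": {"type": "ephemeral"} on that part.

package main

import (
	"fmt"

	"github.com/tmc/langchaingo/llms"
)

func main() {
	// A long, reusable instruction block is the kind of content worth caching.
	systemPrompt := "You are a meticulous research assistant. ..."

	// Passing true marks every text part in this message as cacheable.
	system := llms.TextParts(llms.ChatMessageTypeSystem, true, systemPrompt)

	// Ordinary turns opt out by passing false, matching the updated test call sites.
	question := llms.MessageContent{
		Role:  llms.ChatMessageTypeHuman,
		Parts: []llms.ContentPart{llms.TextPart("What kind of mammal is a pomeranian?", false)},
	}

	fmt.Println(len(system.Parts), len(question.Parts))
}

Threading the flag through TextPart and TextParts keeps caching opt-in, but it changes both signatures, which is why patch 4 has to touch every existing caller.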
From 9fafd43373986479c75606f768392b32e93f018d Mon Sep 17 00:00:00 2001
From: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>
Date: Fri, 23 Aug 2024 02:26:15 +0800
Subject: [PATCH 3/4] add extra headers

---
 llms/openai/internal/openaiclient/chat.go | 3 ++-
 llms/openai/openaillm.go                  | 1 +
 llms/options.go                           | 8 ++++++++
 3 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 407dad6dd..da5a3ec56 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -71,7 +71,8 @@ type ChatRequest struct {
 	FunctionCallBehavior FunctionCallBehavior `json:"function_call,omitempty"`
 
 	// Metadata allows you to specify additional information that will be passed to the model.
-	Metadata map[string]any `json:"metadata,omitempty"`
+	Metadata     map[string]any    `json:"metadata,omitempty"`
+	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 }
 
 // ToolType is the type of a tool.
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index 699c0d304..bb404ef94 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -110,6 +110,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
 		Seed:                 opts.Seed,
 		Metadata:             opts.Metadata,
+		ExtraHeaders:         opts.ExtraHeaders,
 	}
 	if opts.JSONMode {
 		req.ResponseFormat = ResponseFormatJSON
diff --git a/llms/options.go b/llms/options.go
index 281b768a8..fda2cc4d0 100644
--- a/llms/options.go
+++ b/llms/options.go
@@ -61,6 +61,8 @@ type CallOptions struct {
 	// Metadata is a map of metadata to include in the request.
 	// The meaning of this field is specific to the backend in use.
 	Metadata map[string]interface{} `json:"metadata,omitempty"`
+
+	ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
 }
 
 // Tool is a tool that can be used by the model.
@@ -265,3 +267,9 @@ func WithMetadata(metadata map[string]interface{}) CallOption {
 		o.Metadata = metadata
 	}
 }
+
+func WithExtraHeaders(headers map[string]string) CallOption {
+	return func(o *CallOptions) {
+		o.ExtraHeaders = headers
+	}
+}
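Patch 3 threads a free-form header map through the call options: WithExtraHeaders fills CallOptions.ExtraHeaders, and the OpenAI backend copies that map onto ChatRequest.ExtraHeaders. A hedged usage sketch follows, assuming the patched packages are on the import path; the header name and value are placeholders, and this diff by itself only carries the map into the request struct, it does not show the headers being attached to the outgoing HTTP request.

package main

import (
	"context"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
)

func main() {
	// Requires the usual OpenAI credentials in the environment.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	content := []llms.MessageContent{
		llms.TextParts(llms.ChatMessageTypeHuman, false, "Say hello in one word."),
	}

	// WithExtraHeaders populates CallOptions.ExtraHeaders; with this patch the
	// OpenAI backend forwards that map as ChatRequest.ExtraHeaders. The header
	// below is a placeholder for whatever a provider or proxy expects.
	resp, err := llm.GenerateContent(context.Background(), content,
		llms.WithExtraHeaders(map[string]string{"x-example-feature": "enabled"}),
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(resp.Choices[0].Content)
}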
From 29063ede57498b3cc28e347361c4bbc6d4d98633 Mon Sep 17 00:00:00 2001
From: lowjiansheng <15527690+lowjiansheng@users.noreply.github.com>
Date: Mon, 4 Nov 2024 20:21:31 +0800
Subject: [PATCH 4/4] fix test errors

---
 llms/bedrock/bedrockllm_test.go          |  4 +--
 llms/generatecontent_test.go             |  2 +-
 llms/googleai/shared_test/shared_test.go | 32 +++++++++++++-----------
 llms/openai/multicontent_test.go         | 18 ++++++-------
 4 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/llms/bedrock/bedrockllm_test.go b/llms/bedrock/bedrockllm_test.go
index 2df294d8b..19c71d1ac 100644
--- a/llms/bedrock/bedrockllm_test.go
+++ b/llms/bedrock/bedrockllm_test.go
@@ -40,13 +40,13 @@ func TestAmazonOutput(t *testing.T) {
 		{
 			Role: llms.ChatMessageTypeSystem,
 			Parts: []llms.ContentPart{
-				llms.TextPart("You know all about AI."),
+				llms.TextPart("You know all about AI.", false),
 			},
 		},
 		{
 			Role: llms.ChatMessageTypeHuman,
 			Parts: []llms.ContentPart{
-				llms.TextPart("Explain AI in 10 words or less."),
+				llms.TextPart("Explain AI in 10 words or less.", false),
 			},
 		},
 	}
diff --git a/llms/generatecontent_test.go b/llms/generatecontent_test.go
index 71ce77a9f..2257733e9 100644
--- a/llms/generatecontent_test.go
+++ b/llms/generatecontent_test.go
@@ -28,7 +28,7 @@ func TestTextParts(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			t.Parallel()
-			if got := TextParts(tt.args.role, tt.args.parts...); !reflect.DeepEqual(got, tt.want) {
+			if got := TextParts(tt.args.role, false, tt.args.parts...); !reflect.DeepEqual(got, tt.want) {
 				t.Errorf("TextParts() = %v, want %v", got, tt.want)
 			}
 		})
diff --git a/llms/googleai/shared_test/shared_test.go b/llms/googleai/shared_test/shared_test.go
index de72e2cca..6e7ab977b 100644
--- a/llms/googleai/shared_test/shared_test.go
+++ b/llms/googleai/shared_test/shared_test.go
@@ -125,8 +125,8 @@ func testMultiContentText(t *testing.T, llm llms.Model) {
 	t.Parallel()
 
 	parts := []llms.ContentPart{
-		llms.TextPart("I'm a pomeranian"),
-		llms.TextPart("What kind of mammal am I?"),
+		llms.TextPart("I'm a pomeranian", false),
+		llms.TextPart("What kind of mammal am I?", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -151,6 +151,7 @@ func testMultiContentTextUsingTextParts(t *testing.T, llm llms.Model) {
 
 	content := llms.TextParts(
 		llms.ChatMessageTypeHuman,
+		false,
 		"I'm a pomeranian",
 		"What kind of mammal am I?",
 	)
@@ -181,15 +182,15 @@ func testMultiContentTextChatSequence(t *testing.T, llm llms.Model) {
 	content := []llms.MessageContent{
 		{
 			Role:  llms.ChatMessageTypeHuman,
-			Parts: []llms.ContentPart{llms.TextPart("Name some countries")},
+			Parts: []llms.ContentPart{llms.TextPart("Name some countries", false)},
 		},
 		{
 			Role:  llms.ChatMessageTypeAI,
-			Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")},
+			Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)},
 		},
 		{
 			Role:  llms.ChatMessageTypeHuman,
-			Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")},
+			Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)},
 		},
 	}
 
@@ -208,11 +209,11 @@ func testMultiContentWithSystemMessage(t *testing.T, llm llms.Model) {
 	content := []llms.MessageContent{
 		{
 			Role:  llms.ChatMessageTypeSystem,
-			Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish")},
+			Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish", false)},
 		},
 		{
 			Role:  llms.ChatMessageTypeHuman,
-			Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits")},
+			Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits", false)},
 		},
 	}
 
@@ -232,7 +233,7 @@ func testMultiContentImageLink(t *testing.T, llm llms.Model) {
 		llms.ImageURLPart(
 			"https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true",
 		),
-		llms.TextPart("describe this image in detail"),
+		llms.TextPart("describe this image in detail", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -264,7 +265,7 @@ func testMultiContentImageBinary(t *testing.T, llm llms.Model) {
 
 	parts := []llms.ContentPart{
 		llms.BinaryPart("image/png", b),
-		llms.TextPart("what does this image show? please use detail"),
+		llms.TextPart("what does this image show? please use detail", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -304,7 +305,7 @@ func testCandidateCountSetting(t *testing.T, llm llms.Model) {
 	t.Helper()
 
 	parts := []llms.ContentPart{
-		llms.TextPart("Name five countries in Africa"),
+		llms.TextPart("Name five countries in Africa", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -330,6 +331,7 @@ func testWithStreaming(t *testing.T, llm llms.Model) {
 
 	content := llms.TextParts(
 		llms.ChatMessageTypeHuman,
+		false,
 		"I'm a pomeranian",
 		"Tell me more about my taxonomy",
 	)
@@ -376,7 +378,7 @@ func testTools(t *testing.T, llm llms.Model) {
 	}
 
 	content := []llms.MessageContent{
-		llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"),
+		llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"),
 	}
 	resp, err := llm.GenerateContent(
 		context.Background(),
@@ -459,7 +461,7 @@ func testToolsWithInterfaceRequired(t *testing.T, llm llms.Model) {
 	}
 
 	content := []llms.MessageContent{
-		llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"),
+		llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"),
 	}
 	resp, err := llm.GenerateContent(
 		context.Background(),
@@ -523,8 +525,8 @@ func testMaxTokensSetting(t *testing.T, llm llms.Model) {
 	t.Parallel()
 
 	parts := []llms.ContentPart{
-		llms.TextPart("I'm a pomeranian"),
-		llms.TextPart("Describe my taxonomy, health and care"),
+		llms.TextPart("I'm a pomeranian", false),
+		llms.TextPart("Describe my taxonomy, health and care", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -566,7 +568,7 @@ func testWithHTTPClient(t *testing.T, llm llms.Model) {
 
 	resp, err := llm.GenerateContent(
 		context.TODO(),
-		[]llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "testing")},
+		[]llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, false, "testing")},
 	)
 	require.NoError(t, err)
 	require.EqualValues(t, "test-ok", resp.Choices[0].Content)
diff --git a/llms/openai/multicontent_test.go b/llms/openai/multicontent_test.go
index 04bdfe139..863a38499 100644
--- a/llms/openai/multicontent_test.go
+++ b/llms/openai/multicontent_test.go
@@ -29,8 +29,8 @@ func TestMultiContentText(t *testing.T) {
 	llm := newTestClient(t)
 
 	parts := []llms.ContentPart{
-		llms.TextPart("I'm a pomeranian"),
-		llms.TextPart("What kind of mammal am I?"),
+		llms.TextPart("I'm a pomeranian", false),
+		llms.TextPart("What kind of mammal am I?", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -54,15 +54,15 @@ func TestMultiContentTextChatSequence(t *testing.T) {
 	content := []llms.MessageContent{
 		{
 			Role:  llms.ChatMessageTypeHuman,
-			Parts: []llms.ContentPart{llms.TextPart("Name some countries")},
+			Parts: []llms.ContentPart{llms.TextPart("Name some countries", false)},
 		},
 		{
 			Role:  llms.ChatMessageTypeAI,
-			Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")},
+			Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)},
 		},
 		{
 			Role:  llms.ChatMessageTypeHuman,
-			Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")},
+			Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)},
 		},
 	}
 
@@ -81,7 +81,7 @@ func TestMultiContentImage(t *testing.T) {
 
 	parts := []llms.ContentPart{
 		llms.ImageURLPart("https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true"),
-		llms.TextPart("describe this image in detail"),
+		llms.TextPart("describe this image in detail", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -103,8 +103,8 @@ func TestWithStreaming(t *testing.T) {
 	llm := newTestClient(t)
 
 	parts := []llms.ContentPart{
-		llms.TextPart("I'm a pomeranian"),
-		llms.TextPart("Tell me more about my taxonomy"),
+		llms.TextPart("I'm a pomeranian", false),
+		llms.TextPart("Tell me more about my taxonomy", false),
 	}
 	content := []llms.MessageContent{
 		{
@@ -134,7 +134,7 @@ func TestFunctionCall(t *testing.T) {
 	llm := newTestClient(t)
 
 	parts := []llms.ContentPart{
-		llms.TextPart("What is the weather like in Boston?"),
+		llms.TextPart("What is the weather like in Boston?", false),
 	}
 	content := []llms.MessageContent{
 		{