Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support anthropic caching #1059

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llms/bedrock/bedrockllm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func TestAmazonOutput(t *testing.T) {
{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{
llms.TextPart("You know all about AI."),
llms.TextPart("You know all about AI.", false),
},
},
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{
llms.TextPart("Explain AI in 10 words or less."),
llms.TextPart("Explain AI in 10 words or less.", false),
},
},
}
Expand Down
16 changes: 12 additions & 4 deletions llms/generatecontent.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ type MessageContent struct {
}

// TextPart creates TextContent from a given string.
func TextPart(s string) TextContent {
func TextPart(s string, isCacheable bool) TextContent {
if isCacheable {
return TextContent{Text: s, CacheControl: CacheControl{Type: "ephemeral"}}
}
return TextContent{Text: s}
}

Expand Down Expand Up @@ -52,7 +55,12 @@ type ContentPart interface {

// TextContent is content with some text.
type TextContent struct {
Text string
Text string
CacheControl CacheControl `json:"cache_control"`
}

type CacheControl struct {
Type string `json:"type"`
}

func (tc TextContent) String() string {
Expand Down Expand Up @@ -147,13 +155,13 @@ type ContentChoice struct {

// TextParts is a helper function to create a MessageContent with a role and a
// list of text parts.
func TextParts(role ChatMessageType, parts ...string) MessageContent {
func TextParts(role ChatMessageType, isCacheable bool, parts ...string) MessageContent {
result := MessageContent{
Role: role,
Parts: []ContentPart{},
}
for _, part := range parts {
result.Parts = append(result.Parts, TextPart(part))
result.Parts = append(result.Parts, TextPart(part, isCacheable))
}
return result
}
Expand Down
2 changes: 1 addition & 1 deletion llms/generatecontent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestTextParts(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := TextParts(tt.args.role, tt.args.parts...); !reflect.DeepEqual(got, tt.want) {
if got := TextParts(tt.args.role, false, tt.args.parts...); !reflect.DeepEqual(got, tt.want) {
t.Errorf("TextParts() = %v, want %v", got, tt.want)
}
})
Expand Down
32 changes: 17 additions & 15 deletions llms/googleai/shared_test/shared_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ func testMultiContentText(t *testing.T, llm llms.Model) {
t.Parallel()

parts := []llms.ContentPart{
llms.TextPart("I'm a pomeranian"),
llms.TextPart("What kind of mammal am I?"),
llms.TextPart("I'm a pomeranian", false),
llms.TextPart("What kind of mammal am I?", false),
}
content := []llms.MessageContent{
{
Expand All @@ -151,6 +151,7 @@ func testMultiContentTextUsingTextParts(t *testing.T, llm llms.Model) {

content := llms.TextParts(
llms.ChatMessageTypeHuman,
false,
"I'm a pomeranian",
"What kind of mammal am I?",
)
Expand Down Expand Up @@ -181,15 +182,15 @@ func testMultiContentTextChatSequence(t *testing.T, llm llms.Model) {
content := []llms.MessageContent{
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextPart("Name some countries")},
Parts: []llms.ContentPart{llms.TextPart("Name some countries", false)},
},
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")},
Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)},
},
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")},
Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)},
},
}

Expand All @@ -208,11 +209,11 @@ func testMultiContentWithSystemMessage(t *testing.T, llm llms.Model) {
content := []llms.MessageContent{
{
Role: llms.ChatMessageTypeSystem,
Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish")},
Parts: []llms.ContentPart{llms.TextPart("You are a Spanish teacher; answer in Spanish", false)},
},
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits")},
Parts: []llms.ContentPart{llms.TextPart("Name the 5 most common fruits", false)},
},
}

Expand All @@ -232,7 +233,7 @@ func testMultiContentImageLink(t *testing.T, llm llms.Model) {
llms.ImageURLPart(
"https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true",
),
llms.TextPart("describe this image in detail"),
llms.TextPart("describe this image in detail", false),
}
content := []llms.MessageContent{
{
Expand Down Expand Up @@ -264,7 +265,7 @@ func testMultiContentImageBinary(t *testing.T, llm llms.Model) {

parts := []llms.ContentPart{
llms.BinaryPart("image/png", b),
llms.TextPart("what does this image show? please use detail"),
llms.TextPart("what does this image show? please use detail", false),
}
content := []llms.MessageContent{
{
Expand Down Expand Up @@ -304,7 +305,7 @@ func testCandidateCountSetting(t *testing.T, llm llms.Model) {
t.Helper()

parts := []llms.ContentPart{
llms.TextPart("Name five countries in Africa"),
llms.TextPart("Name five countries in Africa", false),
}
content := []llms.MessageContent{
{
Expand All @@ -330,6 +331,7 @@ func testWithStreaming(t *testing.T, llm llms.Model) {

content := llms.TextParts(
llms.ChatMessageTypeHuman,
false,
"I'm a pomeranian",
"Tell me more about my taxonomy",
)
Expand Down Expand Up @@ -376,7 +378,7 @@ func testTools(t *testing.T, llm llms.Model) {
}

content := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"),
llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"),
}
resp, err := llm.GenerateContent(
context.Background(),
Expand Down Expand Up @@ -459,7 +461,7 @@ func testToolsWithInterfaceRequired(t *testing.T, llm llms.Model) {
}

content := []llms.MessageContent{
llms.TextParts(llms.ChatMessageTypeHuman, "What is the weather like in Chicago?"),
llms.TextParts(llms.ChatMessageTypeHuman, false, "What is the weather like in Chicago?"),
}
resp, err := llm.GenerateContent(
context.Background(),
Expand Down Expand Up @@ -523,8 +525,8 @@ func testMaxTokensSetting(t *testing.T, llm llms.Model) {
t.Parallel()

parts := []llms.ContentPart{
llms.TextPart("I'm a pomeranian"),
llms.TextPart("Describe my taxonomy, health and care"),
llms.TextPart("I'm a pomeranian", false),
llms.TextPart("Describe my taxonomy, health and care", false),
}
content := []llms.MessageContent{
{
Expand Down Expand Up @@ -566,7 +568,7 @@ func testWithHTTPClient(t *testing.T, llm llms.Model) {

resp, err := llm.GenerateContent(
context.TODO(),
[]llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, "testing")},
[]llms.MessageContent{llms.TextParts(llms.ChatMessageTypeHuman, false, "testing")},
)
require.NoError(t, err)
require.EqualValues(t, "test-ok", resp.Choices[0].Content)
Expand Down
3 changes: 2 additions & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ type ChatRequest struct {
FunctionCallBehavior FunctionCallBehavior `json:"function_call,omitempty"`

// Metadata allows you to specify additional information that will be passed to the model.
Metadata map[string]any `json:"metadata,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
}

// ToolType is the type of a tool.
Expand Down
18 changes: 9 additions & 9 deletions llms/openai/multicontent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func TestMultiContentText(t *testing.T) {
llm := newTestClient(t)

parts := []llms.ContentPart{
llms.TextPart("I'm a pomeranian"),
llms.TextPart("What kind of mammal am I?"),
llms.TextPart("I'm a pomeranian", false),
llms.TextPart("What kind of mammal am I?", false),
}
content := []llms.MessageContent{
{
Expand All @@ -54,15 +54,15 @@ func TestMultiContentTextChatSequence(t *testing.T) {
content := []llms.MessageContent{
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextPart("Name some countries")},
Parts: []llms.ContentPart{llms.TextPart("Name some countries", false)},
},
{
Role: llms.ChatMessageTypeAI,
Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho")},
Parts: []llms.ContentPart{llms.TextPart("Spain and Lesotho", false)},
},
{
Role: llms.ChatMessageTypeHuman,
Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?")},
Parts: []llms.ContentPart{llms.TextPart("Which if these is larger?", false)},
},
}

Expand All @@ -81,7 +81,7 @@ func TestMultiContentImage(t *testing.T) {

parts := []llms.ContentPart{
llms.ImageURLPart("https://github.com/tmc/langchaingo/blob/main/docs/static/img/parrot-icon.png?raw=true"),
llms.TextPart("describe this image in detail"),
llms.TextPart("describe this image in detail", false),
}
content := []llms.MessageContent{
{
Expand All @@ -103,8 +103,8 @@ func TestWithStreaming(t *testing.T) {
llm := newTestClient(t)

parts := []llms.ContentPart{
llms.TextPart("I'm a pomeranian"),
llms.TextPart("Tell me more about my taxonomy"),
llms.TextPart("I'm a pomeranian", false),
llms.TextPart("Tell me more about my taxonomy", false),
}
content := []llms.MessageContent{
{
Expand Down Expand Up @@ -134,7 +134,7 @@ func TestFunctionCall(t *testing.T) {
llm := newTestClient(t)

parts := []llms.ContentPart{
llms.TextPart("What is the weather like in Boston?"),
llms.TextPart("What is the weather like in Boston?", false),
}
content := []llms.MessageContent{
{
Expand Down
1 change: 1 addition & 0 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
Seed: opts.Seed,
Metadata: opts.Metadata,
ExtraHeaders: opts.ExtraHeaders,
}
if opts.JSONMode {
req.ResponseFormat = ResponseFormatJSON
Expand Down
11 changes: 3 additions & 8 deletions llms/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ type CallOptions struct {
// Supported MIME types are: text/plain: (default) Text output.
// application/json: JSON response in the response candidates.
ResponseMIMEType string `json:"response_mime_type,omitempty"`

// ExtraHeaders is a map of extra headers to include in the request.
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
}

// Tool is a tool that can be used by the model.
Expand Down Expand Up @@ -272,11 +275,3 @@ func WithMetadata(metadata map[string]interface{}) CallOption {
o.Metadata = metadata
}
}

// WithResponseMIMEType will add an option to set the ResponseMIMEType
// Currently only supported by googleai llms.
//
// NOTE(review): this diff appears to delete WithResponseMIMEType while
// keeping the CallOptions.ResponseMIMEType field it sets. Removing a public
// option constructor is a breaking API change unrelated to caching —
// confirm this deletion is intentional and not a rebase artifact.
func WithResponseMIMEType(responseMIMEType string) CallOption {
	return func(o *CallOptions) {
		o.ResponseMIMEType = responseMIMEType
	}
}