From 8dc9ed8f76b85e3650e06a7111a76394e057c70b Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Fri, 7 Feb 2025 18:07:05 -0800 Subject: [PATCH] feat(go): Added `/util/generate` + filled feature gaps in Generate API. (#1818) --- genkit-tools/common/src/types/model.ts | 2 + genkit-tools/genkit-schema.json | 3 + go/ai/action_test.go | 145 ++++++++++ go/ai/document.go | 89 +++--- go/ai/gen.go | 91 ++++-- go/ai/generate.go | 365 ++++++++++++++++++++----- go/ai/generator_test.go | 219 ++++++++++++++- go/ai/tools.go | 67 ++++- go/core/action.go | 7 +- go/core/schemas.config | 46 +++- go/genkit/genkit.go | 8 +- go/go.mod | 2 +- go/internal/atype/atype.go | 1 + go/internal/base/validation.go | 7 +- go/internal/doc-snippets/models.go | 2 +- go/internal/doc-snippets/prompts.go | 2 +- go/plugins/dotprompt/dotprompt.go | 77 ++++-- go/plugins/dotprompt/dotprompt_test.go | 3 +- go/plugins/dotprompt/genkit.go | 72 ++++- go/plugins/googleai/googleai.go | 21 +- go/plugins/googleai/googleai_test.go | 2 +- go/plugins/ollama/ollama_live_test.go | 2 +- go/plugins/vertexai/vertexai.go | 21 +- go/plugins/vertexai/vertexai_test.go | 2 +- go/samples/menu/s02.go | 2 +- go/samples/menu/s03.go | 2 +- 26 files changed, 1067 insertions(+), 193 deletions(-) create mode 100644 go/ai/action_test.go diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index 7846f55bd..cd92b7a25 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -122,6 +122,8 @@ export const ModelInfoSchema = z.object({ context: z.boolean().optional(), /** Model can natively support constrained generation. */ constrained: z.enum(['none', 'all', 'no-tools']).optional(), + /** Model supports controlling tool choice, e.g. forced tool calling. 
*/ + toolChoice: z.boolean().optional(), }) .optional(), }); diff --git a/genkit-tools/genkit-schema.json b/genkit-tools/genkit-schema.json index 6a6010aa8..8927aaedc 100644 --- a/genkit-tools/genkit-schema.json +++ b/genkit-tools/genkit-schema.json @@ -766,6 +766,9 @@ "all", "no-tools" ] + }, + "toolChoice": { + "type": "boolean" } }, "additionalProperties": false diff --git a/go/ai/action_test.go b/go/ai/action_test.go new file mode 100644 index 000000000..297eb2eec --- /dev/null +++ b/go/ai/action_test.go @@ -0,0 +1,145 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ai + +import ( + "context" + "os" + "testing" + + "github.com/firebase/genkit/go/internal/registry" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "gopkg.in/yaml.v3" +) + +type specSuite struct { + Tests []testCase `yaml:"tests"` +} + +type testCase struct { + Name string `yaml:"name"` + Input *GenerateActionOptions `yaml:"input"` + StreamChunks [][]*ModelResponseChunk `yaml:"streamChunks,omitempty"` + ModelResponses []*ModelResponse `yaml:"modelResponses"` + ExpectResponse *ModelResponse `yaml:"expectResponse,omitempty"` + Stream bool `yaml:"stream,omitempty"` + ExpectChunks []*ModelResponseChunk `yaml:"expectChunks,omitempty"` +} + +type programmableModel struct { + r *registry.Registry + handleResp func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) + lastRequest *ModelRequest +} + +func (pm *programmableModel) Name() string { + return "programmableModel" +} + +func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) { + pm.lastRequest = req + return pm.handleResp(ctx, req, cb) +} + +func defineProgrammableModel(r *registry.Registry) *programmableModel { + pm := &programmableModel{r: r} + DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { + return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb) + }) + return pm +} + +func TestGenerateAction(t *testing.T) { + data, err := os.ReadFile("../../tests/specs/generate.yaml") + if err != nil { + t.Fatalf("failed to read spec file: %v", err) + } + + var suite specSuite + if err := yaml.Unmarshal(data, &suite); err != nil { + t.Fatalf("failed to parse spec file: %v", err) + } + + for _, tc := range suite.Tests { + t.Run(tc.Name, func(t *testing.T) { + ctx := 
context.Background() + + r, err := registry.New() + if err != nil { + t.Fatalf("failed to create registry: %v", err) + } + + pm := defineProgrammableModel(r) + + DefineTool(r, "testTool", "description", + func(ctx *ToolContext, input any) (any, error) { + return "tool called", nil + }) + + if len(tc.ModelResponses) > 0 || len(tc.StreamChunks) > 0 { + reqCounter := 0 + pm.handleResp = func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) { + if len(tc.StreamChunks) > 0 && cb != nil { + for _, chunk := range tc.StreamChunks[reqCounter] { + if err := cb(ctx, chunk); err != nil { + return nil, err + } + } + } + resp := tc.ModelResponses[reqCounter] + resp.Request = req + resp.Custom = map[string]any{} + resp.Request.Output = &ModelRequestOutput{} + resp.Usage = &GenerationUsage{} + reqCounter++ + return resp, nil + } + } + + genAction := DefineGenerateAction(ctx, r) + + if tc.Stream { + chunks := []*ModelResponseChunk{} + streamCb := func(ctx context.Context, chunk *ModelResponseChunk) error { + chunks = append(chunks, chunk) + return nil + } + + resp, err := genAction.Run(ctx, tc.Input, streamCb) + if err != nil { + t.Fatalf("action failed: %v", err) + } + + if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" { + t.Errorf("chunks mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" { + t.Errorf("response mismatch (-want +got):\n%s", diff) + } + } else { + resp, err := genAction.Run(ctx, tc.Input, nil) + if err != nil { + t.Fatalf("action failed: %v", err) + } + + if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" { + t.Errorf("response mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/go/ai/document.go b/go/ai/document.go index da753cf6d..53c2368b7 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -1,12 +1,13 @@ // Copyright 2024 Google LLC // 
SPDX-License-Identifier: Apache-2.0 - package ai import ( "encoding/json" "fmt" + + "gopkg.in/yaml.v3" ) // A Document is a piece of data that can be embedded, indexed, or retrieved. @@ -21,11 +22,12 @@ type Document struct { // A Part is one part of a [Document]. This may be plain text or it // may be a URL (possibly a "data:" URL with embedded data). type Part struct { - Kind PartKind `json:"kind,omitempty"` - ContentType string `json:"contentType,omitempty"` // valid for kind==blob - Text string `json:"text,omitempty"` // valid for kind∈{text,blob} - ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for kind==partToolRequest - ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for kind==partToolResponse + Kind PartKind `json:"kind,omitempty"` + ContentType string `json:"contentType,omitempty"` // valid for kind==blob + Text string `json:"text,omitempty"` // valid for kind∈{text,blob} + ToolRequest *ToolRequest `json:"toolRequest,omitempty"` // valid for kind==partToolRequest + ToolResponse *ToolResponse `json:"toolResponse,omitempty"` // valid for kind==partToolResponse + Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds } type PartKind int8 @@ -105,7 +107,8 @@ func (p *Part) MarshalJSON() ([]byte, error) { switch p.Kind { case PartText: v := textPart{ - Text: p.Text, + Text: p.Text, + Metadata: p.Metadata, } return json.Marshal(v) case PartMedia: @@ -114,28 +117,25 @@ func (p *Part) MarshalJSON() ([]byte, error) { ContentType: p.ContentType, Url: p.Text, }, + Metadata: p.Metadata, } return json.Marshal(v) case PartData: v := dataPart{ - Data: p.Text, + Data: p.Text, + Metadata: p.Metadata, } return json.Marshal(v) case PartToolRequest: - // TODO: make sure these types marshal/unmarshal nicely - // between Go and javascript. At the very least the - // field name needs to change (here and in UnmarshalJSON). 
- v := struct { - ToolReq *ToolRequest `json:"toolreq,omitempty"` - }{ - ToolReq: p.ToolRequest, + v := toolRequestPart{ + ToolRequest: p.ToolRequest, + Metadata: p.Metadata, } return json.Marshal(v) case PartToolResponse: - v := struct { - ToolResp *ToolResponse `json:"toolresp,omitempty"` - }{ - ToolResp: p.ToolResponse, + v := toolResponsePart{ + ToolResponse: p.ToolResponse, + Metadata: p.Metadata, } return json.Marshal(v) default: @@ -144,34 +144,27 @@ func (p *Part) MarshalJSON() ([]byte, error) { } type partSchema struct { - Text string `json:"text,omitempty"` - Media *mediaPartMedia `json:"media,omitempty"` - Data string `json:"data,omitempty"` - ToolReq *ToolRequest `json:"toolreq,omitempty"` - ToolResp *ToolResponse `json:"toolresp,omitempty"` + Text string `json:"text,omitempty" yaml:"text,omitempty"` + Media *mediaPartMedia `json:"media,omitempty" yaml:"media,omitempty"` + Data string `json:"data,omitempty" yaml:"data,omitempty"` + ToolRequest *ToolRequest `json:"toolRequest,omitempty" yaml:"toolRequest,omitempty"` + ToolResponse *ToolResponse `json:"toolResponse,omitempty" yaml:"toolResponse,omitempty"` + Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty"` } -// UnmarshalJSON is called by the JSON unmarshaler to read a Part. -func (p *Part) UnmarshalJSON(b []byte) error { - // This is not handled by the schema generator because - // Part is defined in TypeScript as a union. - - var s partSchema - if err := json.Unmarshal(b, &s); err != nil { - return err - } - +// unmarshalPartFromSchema updates Part p based on the schema s. 
+func (p *Part) unmarshalPartFromSchema(s partSchema) { switch { case s.Media != nil: p.Kind = PartMedia p.Text = s.Media.Url p.ContentType = s.Media.ContentType - case s.ToolReq != nil: + case s.ToolRequest != nil: p.Kind = PartToolRequest - p.ToolRequest = s.ToolReq - case s.ToolResp != nil: + p.ToolRequest = s.ToolRequest + case s.ToolResponse != nil: p.Kind = PartToolResponse - p.ToolResponse = s.ToolResp + p.ToolResponse = s.ToolResponse default: p.Kind = PartText p.Text = s.Text @@ -182,6 +175,26 @@ func (p *Part) UnmarshalJSON(b []byte) error { p.Text = s.Data } } + p.Metadata = s.Metadata +} + +// UnmarshalJSON is called by the JSON unmarshaler to read a Part. +func (p *Part) UnmarshalJSON(b []byte) error { + var s partSchema + if err := json.Unmarshal(b, &s); err != nil { + return err + } + p.unmarshalPartFromSchema(s) + return nil +} + +// UnmarshalYAML implements yaml.Unmarshaler for Part. +func (p *Part) UnmarshalYAML(value *yaml.Node) error { + var s partSchema + if err := value.Decode(&s); err != nil { + return err + } + p.unmarshalPartFromSchema(s) return nil } diff --git a/go/ai/gen.go b/go/ai/gen.go index 534bde39a..265051cd3 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -1,4 +1,4 @@ -// Copyright 2024 Google LLC +// Copyright 2025 Google LLC // SPDX-License-Identifier: Apache-2.0 // This file was generated by jsonschemagen. DO NOT EDIT. 
@@ -10,6 +10,43 @@ type dataPart struct { Metadata map[string]any `json:"metadata,omitempty"` } +type FinishReason string + +const ( + FinishReasonStop FinishReason = "stop" + FinishReasonLength FinishReason = "length" + FinishReasonBlocked FinishReason = "blocked" + FinishReasonInterrupted FinishReason = "interrupted" + FinishReasonOther FinishReason = "other" + FinishReasonUnknown FinishReason = "unknown" +) + +type GenerateActionOptions struct { + Config any `json:"config,omitempty"` + Docs []*Document `json:"docs,omitempty"` + MaxTurns int `json:"maxTurns,omitempty"` + Messages []*Message `json:"messages,omitempty"` + Model string `json:"model,omitempty"` + Output *GenerateActionOptionsOutput `json:"output,omitempty"` + ReturnToolRequests bool `json:"returnToolRequests,omitempty"` + ToolChoice ToolChoice `json:"toolChoice,omitempty"` + Tools []string `json:"tools,omitempty"` +} + +type GenerateActionOptionsOutput struct { + ContentType string `json:"contentType,omitempty"` + Format OutputFormat `json:"format,omitempty"` + JsonSchema map[string]any `json:"jsonSchema,omitempty"` +} + +type ToolChoice string + +const ( + ToolChoiceAuto ToolChoice = "auto" + ToolChoiceRequired ToolChoice = "required" + ToolChoiceNone ToolChoice = "none" +) + type ModelRequestOutput struct { Format OutputFormat `json:"format,omitempty"` Schema map[string]any `json:"schema,omitempty"` @@ -74,12 +111,15 @@ type ModelInfo struct { } type ModelInfoSupports struct { - Context bool `json:"context,omitempty"` - Media bool `json:"media,omitempty"` - Multiturn bool `json:"multiturn,omitempty"` - Output OutputFormat `json:"output,omitempty"` - SystemRole bool `json:"systemRole,omitempty"` - Tools bool `json:"tools,omitempty"` + Constrained bool `json:"constrained,omitempty"` + ContentType []string `json:"contentType,omitempty"` + Context bool `json:"context,omitempty"` + Media bool `json:"media,omitempty"` + Multiturn bool `json:"multiturn,omitempty"` + Output OutputFormat 
`json:"output,omitempty"` + SystemRole bool `json:"systemRole,omitempty"` + ToolChoice bool `json:"toolChoice,omitempty"` + Tools bool `json:"tools,omitempty"` } // A ModelRequest is a request to generate completions from a model. @@ -88,7 +128,8 @@ type ModelRequest struct { Context []any `json:"context,omitempty"` Messages []*Message `json:"messages,omitempty"` // Output describes the desired response format. - Output *ModelRequestOutput `json:"output,omitempty"` + Output *ModelRequestOutput `json:"output,omitempty"` + ToolChoice ToolChoice `json:"toolChoice,omitempty"` // Tools lists the available tools that the model can ask the client to run. Tools []*ToolDefinition `json:"tools,omitempty"` } @@ -113,18 +154,10 @@ type ModelResponseChunk struct { Aggregated bool `json:"aggregated,omitempty"` Content []*Part `json:"content,omitempty"` Custom any `json:"custom,omitempty"` + Index int `json:"index,omitempty"` + Role Role `json:"role,omitempty"` } -type FinishReason string - -const ( - FinishReasonStop FinishReason = "stop" - FinishReasonLength FinishReason = "length" - FinishReasonBlocked FinishReason = "blocked" - FinishReasonOther FinishReason = "other" - FinishReasonUnknown FinishReason = "unknown" -) - // Role indicates which entity is responsible for the content of a message. type Role string @@ -150,19 +183,32 @@ type ToolDefinition struct { Description string `json:"description,omitempty"` // Valid JSON Schema representing the input of the tool. InputSchema map[string]any `json:"inputSchema,omitempty"` - Name string `json:"name,omitempty"` + // additional metadata for this tool definition + Metadata map[string]any `json:"metadata,omitempty"` + Name string `json:"name,omitempty"` // Valid JSON Schema describing the output of the tool. 
OutputSchema map[string]any `json:"outputSchema,omitempty"` } +type toolRequestPart struct { + Metadata map[string]any `json:"metadata,omitempty"` + ToolRequest *ToolRequest `json:"toolRequest,omitempty"` +} + // A ToolRequest is a message from the model to the client that it should run a // specific tool and pass a [ToolResponse] to the model on the next chat request it makes. // Any ToolRequest will correspond to some [ToolDefinition] previously sent by the client. type ToolRequest struct { // Input is a JSON object describing the input values to the tool. // An example might be map[string]any{"country":"USA", "president":3}. - Input map[string]any `json:"input,omitempty"` - Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Name string `json:"name,omitempty"` + Ref string `json:"ref,omitempty"` +} + +type toolResponsePart struct { + Metadata map[string]any `json:"metadata,omitempty"` + ToolResponse *ToolResponse `json:"toolResponse,omitempty"` } // A ToolResponse is a message from the client to the model containing @@ -172,5 +218,6 @@ type ToolResponse struct { Name string `json:"name,omitempty"` // Output is a JSON object describing the results of running the tool. // An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. 
- Output map[string]any `json:"output,omitempty"` + Output any `json:"output,omitempty"` + Ref string `json:"ref,omitempty"` } diff --git a/go/ai/generate.go b/go/ai/generate.go index 7356ceeef..cf4fe202a 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -8,12 +8,12 @@ import ( "encoding/json" "errors" "fmt" - "slices" "strconv" "strings" "github.com/firebase/genkit/go/core" "github.com/firebase/genkit/go/core/logger" + "github.com/firebase/genkit/go/core/tracing" "github.com/firebase/genkit/go/internal/atype" "github.com/firebase/genkit/go/internal/base" "github.com/firebase/genkit/go/internal/registry" @@ -24,16 +24,82 @@ type Model interface { // Name returns the registry name of the model. Name() string // Generate applies the [Model] to provided request, handling tool requests and handles streaming. - Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) + Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) } type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk] type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk] +type generateAction = core.Action[*GenerateActionOptions, *ModelResponse, *ModelResponseChunk] + // ModelStreamingCallback is the type for the streaming callback of a model. type ModelStreamingCallback = func(context.Context, *ModelResponseChunk) error +// ToolConfig handles configuration around tool calls during generation. +type ToolConfig struct { + MaxTurns int + ReturnToolRequests bool +} + +// DefineGenerateAction defines a utility generate action. 
+func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction { + return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{}, + func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) { + logger.FromContext(ctx).Debug("GenerateAction", + "input", fmt.Sprintf("%#v", req)) + defer func() { + logger.FromContext(ctx).Debug("GenerateAction", + "output", fmt.Sprintf("%#v", output), + "err", err) + }() + return tracing.RunInNewSpan(ctx, r.TracingState(), "generate", "util", false, req, + func(ctx context.Context, input *GenerateActionOptions) (*ModelResponse, error) { + model := LookupModel(r, "default", req.Model) + if model == nil { + return nil, fmt.Errorf("model %q not found", req.Model) + } + + toolDefs := make([]*ToolDefinition, len(req.Tools)) + for i, toolName := range req.Tools { + toolDefs[i] = LookupTool(r, toolName).Definition() + } + + modelReq := &ModelRequest{ + Messages: req.Messages, + Config: req.Config, + Tools: toolDefs, + ToolChoice: req.ToolChoice, + } + + if req.Output != nil { + modelReq.Output = &ModelRequestOutput{ + Format: req.Output.Format, + Schema: req.Output.JsonSchema, + } + } + + if modelReq.Output != nil && + modelReq.Output.Schema != nil && + modelReq.Output.Format == "" { + modelReq.Output.Format = OutputFormatJSON + } + + maxTurns := 5 + if req.MaxTurns > 0 { + maxTurns = req.MaxTurns + } + + toolCfg := &ToolConfig{ + MaxTurns: maxTurns, + ReturnToolRequests: req.ReturnToolRequests, + } + + return model.Generate(ctx, r, modelReq, toolCfg, cb) + }) + })) +} + // DefineModel registers the given generate function as an action, and returns a // [Model] that runs it. func DefineModel( @@ -85,11 +151,13 @@ func LookupModel(r *registry.Registry, provider, name string) Model { // generateParams represents various params of the Generate call. 
type generateParams struct { - Request *ModelRequest - Model Model - Stream ModelStreamingCallback - SystemPrompt *Message - History []*Message + Request *ModelRequest + Model Model + Stream ModelStreamingCallback + History []*Message + SystemPrompt *Message + MaxTurns int + ReturnToolRequests bool } // GenerateOption configures params of the Generate call. @@ -116,7 +184,7 @@ func WithTextPrompt(prompt string) GenerateOption { func WithSystemPrompt(prompt string) GenerateOption { return func(req *generateParams) error { if req.SystemPrompt != nil { - return errors.New("cannot set system prompt (WithSystemPrompt) more than once") + return errors.New("generate.WithSystemPrompt: cannot set system prompt more than once") } req.SystemPrompt = NewSystemTextMessage(prompt) return nil @@ -138,7 +206,7 @@ func WithMessages(messages ...*Message) GenerateOption { func WithHistory(history ...*Message) GenerateOption { return func(req *generateParams) error { if req.History != nil { - return errors.New("cannot set history (WithHistory) more than once") + return errors.New("generate.WithHistory: cannot set history more than once") } req.History = history return nil @@ -149,7 +217,7 @@ func WithHistory(history ...*Message) GenerateOption { func WithConfig(config any) GenerateOption { return func(req *generateParams) error { if req.Request.Config != nil { - return errors.New("cannot set Request.Config (WithConfig) more than once") + return errors.New("generate.WithConfig: cannot set config more than once") } req.Request.Config = config return nil @@ -168,7 +236,7 @@ func WithContext(c ...any) GenerateOption { func WithTools(tools ...Tool) GenerateOption { return func(req *generateParams) error { if req.Request.Tools != nil { - return errors.New("cannot set Request.Tools (WithTools) more than once") + return errors.New("generate.WithTools: cannot set tools more than once") } var toolDefs []*ToolDefinition for _, t := range tools { @@ -183,7 +251,7 @@ func WithTools(tools ...Tool) 
GenerateOption { func WithOutputSchema(schema any) GenerateOption { return func(req *generateParams) error { if req.Request.Output != nil && req.Request.Output.Schema != nil { - return errors.New("cannot set Request.Output.Schema (WithOutputSchema) more than once") + return errors.New("generate.WithOutputSchema: cannot set output schema more than once") } if req.Request.Output == nil { req.Request.Output = &ModelRequestOutput{} @@ -209,24 +277,62 @@ func WithOutputFormat(format OutputFormat) GenerateOption { func WithStreaming(cb ModelStreamingCallback) GenerateOption { return func(req *generateParams) error { if req.Stream != nil { - return errors.New("cannot set streaming callback (WithStreaming) more than once") + return errors.New("generate.WithStreaming: cannot set streaming callback more than once") } req.Stream = cb return nil } } +// WithMaxTurns sets the maximum number of tool call iterations for the generate request. +func WithMaxTurns(maxTurns int) GenerateOption { + return func(req *generateParams) error { + if maxTurns <= 0 { + return fmt.Errorf("maxTurns must be greater than 0, got %d", maxTurns) + } + if req.MaxTurns != 0 { + return errors.New("generate.WithMaxTurns: cannot set MaxTurns more than once") + } + req.MaxTurns = maxTurns + return nil + } +} + +// WithReturnToolRequests configures whether to return tool requests instead of making the tool calls and continuing the generation. +func WithReturnToolRequests(returnToolRequests bool) GenerateOption { + return func(req *generateParams) error { + if req.ReturnToolRequests { + return errors.New("generate.WithReturnToolRequests: cannot set ReturnToolRequests more than once") + } + req.ReturnToolRequests = returnToolRequests + return nil + } +} + +// WithToolChoice configures whether tool calls are required, disabled, or optional for the generate request. 
+func WithToolChoice(toolChoice ToolChoice) GenerateOption { + return func(req *generateParams) error { + if req.Request.ToolChoice != "" { + return errors.New("generate.WithToolChoice: cannot set ToolChoice more than once") + } + req.Request.ToolChoice = toolChoice + return nil + } +} + // Generate run generate request for this model. Returns ModelResponse struct. func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) { req := &generateParams{ Request: &ModelRequest{}, } + for _, with := range opts { err := with(req) if err != nil { return nil, err } } + if req.Model == nil { return nil, errors.New("model is required") } @@ -253,8 +359,16 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) req.Request.Messages = []*Message{req.SystemPrompt} req.Request.Messages = append(req.Request.Messages, prev...) } + if req.MaxTurns == 0 { + req.MaxTurns = 1 + } - return req.Model.Generate(ctx, r, req.Request, req.Stream) + toolCfg := &ToolConfig{ + MaxTurns: req.MaxTurns, + ReturnToolRequests: req.ReturnToolRequests, + } + + return req.Model.Generate(ctx, r, req.Request, toolCfg, req.Stream) } // validateModelVersion checks in the registry the action of the @@ -321,17 +435,37 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ... } // Generate applies the [Action] to provided request, handling tool requests and handles streaming. 
-func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { +func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) { if m == nil { return nil, errors.New("Generate called on a nil Model; check that all models are defined") } + + if toolCfg == nil { + toolCfg = &ToolConfig{ + MaxTurns: 1, + ReturnToolRequests: false, + } + } + + // TODO: Add warnings if the model does not support certain configuration options. + + if req.Tools != nil { + toolNames := make(map[string]bool) + for _, tool := range req.Tools { + if toolNames[tool.Name] { + return nil, fmt.Errorf("duplicate tool name found: %q", tool.Name) + } + toolNames[tool.Name] = true + } + } + if err := conformOutput(req); err != nil { return nil, err } - a := (*core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk])(m) + currentTurn := 0 for { - resp, err := a.Run(ctx, req, cb) + resp, err := (*modelAction)(m).Run(ctx, req, cb) if err != nil { return nil, err } @@ -342,20 +476,168 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req } resp.Message = msg - newReq, err := handleToolRequest(ctx, r, req, resp) + toolCount := 0 + for _, part := range resp.Message.Content { + if part.IsToolRequest() { + toolCount++ + } + } + if toolCount == 0 || toolCfg.ReturnToolRequests { + return resp, nil + } + + if currentTurn+1 > toolCfg.MaxTurns { + return nil, fmt.Errorf("exceeded maximum tool call iterations (%d)", toolCfg.MaxTurns) + } + + newReq, interruptMsg, err := handleToolRequests(ctx, r, req, resp, cb) if err != nil { return nil, err } + if interruptMsg != nil { + resp.FinishReason = "interrupted" + resp.FinishMessage = "One or more tool calls resulted in interrupts." 
+ resp.Message = interruptMsg + return resp, nil + } if newReq == nil { return resp, nil } req = newReq + currentTurn++ } } func (i *modelActionDef) Name() string { return (*modelAction)(i).Name() } +// cloneMessage creates a deep copy of the provided Message. +func cloneMessage(m *Message) *Message { + if m == nil { + return nil + } + + bytes, err := json.Marshal(m) + if err != nil { + panic(fmt.Sprintf("failed to marshal message: %v", err)) + } + + var copy Message + if err := json.Unmarshal(bytes, &copy); err != nil { + panic(fmt.Sprintf("failed to unmarshal message: %v", err)) + } + + return &copy +} + +// handleToolRequests processes any tool requests in the response, returning either a new request to continue the conversation or nil if no tool requests need handling. +func handleToolRequests(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse, cb ModelStreamingCallback) (*ModelRequest, *Message, error) { + toolCount := 0 + for _, part := range resp.Message.Content { + if part.IsToolRequest() { + toolCount++ + } + } + + if toolCount == 0 { + return nil, nil, nil + } + + type toolResult struct { + index int + output any + err error + } + + resultChan := make(chan toolResult) + toolMessage := &Message{Role: RoleTool} + revisedMessage := cloneMessage(resp.Message) + + for i, part := range resp.Message.Content { + if !part.IsToolRequest() { + continue + } + + go func(idx int, p *Part) { + toolReq := p.ToolRequest + tool := LookupTool(r, toolReq.Name) + if tool == nil { + resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q not found", toolReq.Name)} + return + } + + output, err := tool.RunRaw(ctx, toolReq.Input) + if err != nil { + var interruptErr *ToolInterruptError + if errors.As(err, &interruptErr) { + logger.FromContext(ctx).Debug("tool %q triggered an interrupt: %v", toolReq.Name, interruptErr.Metadata) + revisedMessage.Content[idx] = &Part{ + ToolRequest: toolReq, + Metadata: map[string]any{ + "interrupt": interruptErr.Metadata, + }, 
+ } + resultChan <- toolResult{idx, nil, interruptErr} + return + } + resultChan <- toolResult{idx, nil, fmt.Errorf("tool %q failed: %w", toolReq.Name, err)} + return + } + + revisedMessage.Content[idx] = &Part{ + ToolRequest: toolReq, + Metadata: map[string]any{ + "pendingOutput": output, + }, + } + + resultChan <- toolResult{idx, output, nil} + }(i, part) + } + + var toolResponses []*Part + hasInterrupts := false + for i := 0; i < toolCount; i++ { + result := <-resultChan + if result.err != nil { + var interruptErr *ToolInterruptError + if errors.As(result.err, &interruptErr) { + hasInterrupts = true + continue + } + return nil, nil, result.err + } + + toolReq := resp.Message.Content[result.index].ToolRequest + toolResponses = append(toolResponses, NewToolResponsePart(&ToolResponse{ + Name: toolReq.Name, + Ref: toolReq.Ref, + Output: result.output, + })) + } + + if hasInterrupts { + return nil, revisedMessage, nil + } + + toolMessage.Content = toolResponses + + if cb != nil { + err := cb(ctx, &ModelResponseChunk{ + Content: toolMessage.Content, + Role: RoleTool, + }) + if err != nil { + return nil, nil, fmt.Errorf("streaming callback failed: %w", err) + } + } + + newReq := req + newReq.Messages = append(append([]*Message{}, req.Messages...), resp.Message, toolMessage) + + return newReq, nil, nil +} + // conformOutput appends a message to the request indicating conformance to the expected schema. func conformOutput(req *ModelRequest) error { if req.Output != nil && req.Output.Format == OutputFormatJSON && len(req.Messages) > 0 { @@ -375,12 +657,11 @@ func conformOutput(req *ModelRequest) error { // It will strip JSON markdown delimiters from the response. 
func validResponse(ctx context.Context, resp *ModelResponse) (*Message, error) { msg, err := validMessage(resp.Message, resp.Request.Output) - if err == nil { - return msg, nil - } else { + if err != nil { logger.FromContext(ctx).Debug("message did not match expected schema", "error", err.Error()) return nil, errors.New("generation did not result in a message matching expected schema") } + return msg, nil } // validMessage will validate the message against the expected schema. @@ -409,48 +690,6 @@ func validMessage(m *Message, output *ModelRequestOutput) (*Message, error) { return m, nil } -// handleToolRequest checks if a tool was requested by a model. -// If a tool was requested, this runs the tool and returns an -// updated ModelRequest. If no tool was requested this returns nil. -func handleToolRequest(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse) (*ModelRequest, error) { - msg := resp.Message - if msg == nil || len(msg.Content) == 0 { - return nil, nil - } - part := msg.Content[0] - if !part.IsToolRequest() { - return nil, nil - } - - toolReq := part.ToolRequest - tool := LookupTool(r, toolReq.Name) - if tool == nil { - return nil, fmt.Errorf("tool %v not found", toolReq.Name) - } - to, err := tool.RunRaw(ctx, toolReq.Input) - if err != nil { - return nil, err - } - - toolResp := &Message{ - Content: []*Part{ - NewToolResponsePart(&ToolResponse{ - Name: toolReq.Name, - Output: map[string]any{ - "response": to, - }, - }), - }, - Role: RoleTool, - } - - // Copy the ModelRequest rather than modifying it. - rreq := *req - rreq.Messages = append(slices.Clip(rreq.Messages), msg, toolResp) - - return &rreq, nil -} - // Text returns the contents of the first candidate in a // [ModelResponse] as a string. It returns an empty string if there // are no candidates or if the candidate has no message. 
diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index eb515e22b..36337c423 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -5,6 +5,7 @@ package ai import ( "context" + "fmt" "math" "strings" "testing" @@ -57,7 +58,7 @@ var ( // with tools var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken", - func(ctx context.Context, input struct { + func(ctx *ToolContext, input struct { Value float64 Over float64 }, @@ -333,6 +334,222 @@ func TestGenerate(t *testing.T) { t.Errorf("Request diff (+got -want):\n%s", diff) } }) + + t.Run("handles tool interrupts", func(t *testing.T) { + interruptTool := DefineTool(r, "interruptor", "always interrupts", + func(ctx *ToolContext, input any) (any, error) { + return nil, ctx.Interrupt(&InterruptOptions{ + Metadata: map[string]any{ + "reason": "test interrupt", + }, + }) + }, + ) + + interruptModel := DefineModel(r, "test", "interrupt", nil, + func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "interruptor", + Input: nil, + }), + }, + }, + }, nil + }) + + res, err := Generate(context.Background(), r, + WithModel(interruptModel), + WithTextPrompt("trigger interrupt"), + WithTools(interruptTool), + ) + if err != nil { + t.Fatal(err) + } + if res.FinishReason != "interrupted" { + t.Errorf("expected finish reason 'interrupted', got %q", res.FinishReason) + } + if res.FinishMessage != "One or more tool calls resulted in interrupts." 
{ + t.Errorf("unexpected finish message: %q", res.FinishMessage) + } + + if len(res.Message.Content) != 1 { + t.Fatalf("expected 1 content part, got %d", len(res.Message.Content)) + } + + metadata := res.Message.Content[0].Metadata + if metadata == nil { + t.Fatal("expected metadata in content part") + } + + interrupt, ok := metadata["interrupt"].(map[string]any) + if !ok { + t.Fatal("expected interrupt metadata") + } + + reason, ok := interrupt["reason"].(string) + if !ok || reason != "test interrupt" { + t.Errorf("expected interrupt reason 'test interrupt', got %v", reason) + } + }) + + t.Run("handles multiple parallel tool calls", func(t *testing.T) { + roundCount := 0 + parallelModel := DefineModel(r, "test", "parallel", nil, + func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + roundCount++ + if roundCount == 1 { + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "gablorken", + Input: map[string]any{"Value": 2, "Over": 3}, + }), + NewToolRequestPart(&ToolRequest{ + Name: "gablorken", + Input: map[string]any{"Value": 3, "Over": 2}, + }), + }, + }, + }, nil + } + var sum float64 + for _, msg := range gr.Messages { + if msg.Role == RoleTool { + for _, part := range msg.Content { + if part.ToolResponse != nil { + sum += part.ToolResponse.Output.(float64) + } + } + } + } + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart(fmt.Sprintf("Final result: %d", int(sum))), + }, + }, + }, nil + }) + + res, err := Generate(context.Background(), r, + WithModel(parallelModel), + WithTextPrompt("trigger parallel tools"), + WithTools(gablorkenTool), + ) + if err != nil { + t.Fatal(err) + } + + finalPart := res.Message.Content[0] + if finalPart.Text != "Final result: 17" { + t.Errorf("expected final result text to be 'Final result: 17', got %q", finalPart.Text) + } + }) + + 
t.Run("handles multiple rounds of tool calls", func(t *testing.T) { + roundCount := 0 + multiRoundModel := DefineModel(r, "test", "multiround", nil, + func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + roundCount++ + if roundCount == 1 { + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "gablorken", + Input: map[string]any{"Value": 2, "Over": 3}, + }), + }, + }, + }, nil + } + if roundCount == 2 { + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "gablorken", + Input: map[string]any{"Value": 3, "Over": 2}, + }), + }, + }, + }, nil + } + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewTextPart("Final result"), + }, + }, + }, nil + }) + + res, err := Generate(context.Background(), r, + WithModel(multiRoundModel), + WithTextPrompt("trigger multiple rounds"), + WithTools(gablorkenTool), + WithMaxTurns(2), + ) + if err != nil { + t.Fatal(err) + } + + if roundCount != 3 { + t.Errorf("expected 3 rounds, got %d", roundCount) + } + + if res.Text() != "Final result" { + t.Errorf("expected final message 'Final result', got %q", res.Text()) + } + }) + + t.Run("exceeds maximum turns", func(t *testing.T) { + infiniteModel := DefineModel(r, "test", "infinite", nil, + func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + return &ModelResponse{ + Request: gr, + Message: &Message{ + Role: RoleModel, + Content: []*Part{ + NewToolRequestPart(&ToolRequest{ + Name: "gablorken", + Input: map[string]any{"Value": 2, "Over": 2}, + }), + }, + }, + }, nil + }) + + _, err := Generate(context.Background(), r, + WithModel(infiniteModel), + WithTextPrompt("trigger infinite loop"), + WithTools(gablorkenTool), + WithMaxTurns(2), + ) + + if err == nil { + 
t.Fatal("expected error for exceeding maximum turns") + } + if !strings.Contains(err.Error(), "exceeded maximum tool call iterations (2)") { + t.Errorf("unexpected error message: %v", err) + } + }) } func TestModelVersion(t *testing.T) { diff --git a/go/ai/tools.go b/go/ai/tools.go index 6519cd1ab..912a2af71 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -37,18 +36,51 @@ type Tool interface { Action() action.Action // RunRaw runs this tool using the provided raw map format data (JSON parsed // as map[string]any). - RunRaw(ctx context.Context, input map[string]any) (any, error) + RunRaw(ctx context.Context, input any) (any, error) +} + +// ToolInterruptError represents an intentional interruption of tool execution. +type ToolInterruptError struct { + Metadata map[string]any +} + +func (e *ToolInterruptError) Error() string { + return "tool execution interrupted" +} + +// InterruptOptions provides configuration for tool interruption. +type InterruptOptions struct { + Metadata map[string]any } -// DefineTool defines a tool function. -func DefineTool[In, Out any](r *registry.Registry, name, description string, fn func(ctx context.Context, input In) (Out, error)) *ToolDef[In, Out] { +// ToolContext provides context and utility functions for tool execution. 
+type ToolContext struct { + context.Context + Interrupt func(opts *InterruptOptions) error +} + +// DefineTool defines a tool function with interrupt capability +func DefineTool[In, Out any](r *registry.Registry, name, description string, + fn func(ctx *ToolContext, input In) (Out, error)) *ToolDef[In, Out] { + metadata := make(map[string]any) metadata["type"] = "tool" metadata["name"] = name metadata["description"] = description - toolAction := core.DefineAction(r, provider, name, atype.Tool, metadata, fn) + wrappedFn := func(ctx context.Context, input In) (Out, error) { + toolCtx := &ToolContext{ + Context: ctx, + Interrupt: func(opts *InterruptOptions) error { + return &ToolInterruptError{ + Metadata: opts.Metadata, + } + }, + } + return fn(toolCtx, input) + } + toolAction := core.DefineAction(r, provider, name, atype.Tool, metadata, wrappedFn) return &ToolDef[In, Out]{ action: toolAction, } @@ -75,41 +107,46 @@ func (ta *toolAction) Definition() *ToolDefinition { } func definition(ta Tool) *ToolDefinition { - return &ToolDefinition{ - Name: ta.Action().Desc().Metadata["name"].(string), - Description: ta.Action().Desc().Metadata["description"].(string), - InputSchema: base.SchemaAsMap(ta.Action().Desc().InputSchema), - OutputSchema: base.SchemaAsMap(ta.Action().Desc().OutputSchema), + td := &ToolDefinition{ + Name: ta.Action().Desc().Metadata["name"].(string), + Description: ta.Action().Desc().Metadata["description"].(string), + } + if ta.Action().Desc().InputSchema != nil { + td.InputSchema = base.SchemaAsMap(ta.Action().Desc().InputSchema) + } + if ta.Action().Desc().OutputSchema != nil { + td.OutputSchema = base.SchemaAsMap(ta.Action().Desc().OutputSchema) } + return td } // RunRaw runs this tool using the provided raw map format data (JSON parsed // as map[string]any). 
-func (ta *toolAction) RunRaw(ctx context.Context, input map[string]any) (any, error) { +func (ta *toolAction) RunRaw(ctx context.Context, input any) (any, error) { return runAction(ctx, ta, input) } // RunRaw runs this tool using the provided raw map format data (JSON parsed // as map[string]any). -func (ta *ToolDef[In, Out]) RunRaw(ctx context.Context, input map[string]any) (any, error) { +func (ta *ToolDef[In, Out]) RunRaw(ctx context.Context, input any) (any, error) { return runAction(ctx, ta, input) } -func runAction(ctx context.Context, action Tool, input map[string]any) (any, error) { +func runAction(ctx context.Context, action Tool, input any) (any, error) { mi, err := json.Marshal(input) if err != nil { return nil, fmt.Errorf("error marshalling tool input for %v: %v", action.Definition().Name, err) } output, err := action.Action().RunJSON(ctx, mi, nil) if err != nil { - return nil, fmt.Errorf("error calling tool %v: %v", action.Definition().Name, err) + return nil, fmt.Errorf("error calling tool %v: %w", action.Definition().Name, err) } var uo any err = json.Unmarshal(output, &uo) if err != nil { - return nil, fmt.Errorf("error parsing tool input for %v: %v", action.Definition().Name, err) + return nil, fmt.Errorf("error parsing tool output for %v: %v", action.Definition().Name, err) } return uo, nil } diff --git a/go/core/action.go b/go/core/action.go index abb14ed3d..89d360a82 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package core import ( @@ -142,6 +141,10 @@ func newAction[In, Out, Stream any]( inputSchema = base.InferJSONSchema(i) } } + var outputSchema *jsonschema.Schema + if reflect.ValueOf(o).Kind() != reflect.Invalid { + outputSchema = base.InferJSONSchema(o) + } return &Action[In, Out, Stream]{ name: name, atype: atype, @@ -150,7 +153,7 @@ func newAction[In, Out, Stream any]( return fn(ctx, input, sc) }, inputSchema: inputSchema, - outputSchema: 
base.InferJSONSchema(o), + outputSchema: outputSchema, metadata: metadata, } } diff --git a/go/core/schemas.config b/go/core/schemas.config index ecf4b8784..2c57159e1 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -96,23 +96,23 @@ RoleTool indicates this message was generated by a local tool, likely triggered from the model in one of its previous responses. . -ToolRequestPart omit -ToolRequestPartToolRequest name ToolRequest -ToolResponsePart omit -ToolResponsePartToolResponse name ToolResponse - -ToolRequestPartToolRequest.input type map[string]any +ToolRequestPart pkg ai +ToolRequestPart name toolRequestPart +ToolRequestPartToolRequest name ToolRequest +ToolResponsePart pkg ai +ToolResponsePart name toolResponsePart +ToolResponsePartToolResponse name ToolResponse + +ToolRequestPartToolRequest.input type any ToolRequestPartToolRequest.input doc Input is a JSON object describing the input values to the tool. An example might be map[string]any{"country":"USA", "president":3}. . -ToolResponsePartToolResponse.output type map[string]any +ToolResponsePartToolResponse.output type any ToolResponsePartToolResponse.output doc Output is a JSON object describing the results of running the tool. An example might be map[string]any{"name":"Thomas Jefferson", "born":1743}. . 
-ToolRequestPartToolRequest.ref omit -ToolResponsePartToolResponse.ref omit ToolRequestPartToolRequest doc A ToolRequest is a message from the model to the client that it should run a @@ -135,6 +135,7 @@ GenerateRequest omit GenerateRequestOutput pkg ai GenerateRequestOutput name ModelRequestOutput GenerateRequestOutputFormat pkg ai +GenerateRequestToolChoice omit GenerationUsage pkg ai GenerationUsage.inputCharacters type int GenerationUsage.inputImages type int @@ -167,6 +168,14 @@ MediaPart.toolRequest omit MediaPart.toolResponse omit MediaPartMedia pkg ai MediaPartMedia name mediaPartMedia +ToolRequestPart.text omit +ToolRequestPart.media omit +ToolRequestPart.data omit +ToolRequestPart.toolResponse omit +ToolResponsePart.text omit +ToolResponsePart.media omit +ToolResponsePart.data omit +ToolResponsePart.toolRequest omit DataPart pkg ai DataPart name dataPart ModelInfo pkg ai @@ -177,6 +186,22 @@ RoleUser pkg ai RoleModel pkg ai RoleTool pkg ai +# GenerateActionOptions +GenerateActionOptions pkg ai +GenerateActionOptions.model type string +GenerateActionOptions.docs type []*Document +GenerateActionOptions.messages type []*Message +GenerateActionOptions.tools type []*ToolDefinition +GenerateActionOptionsToolChoice name ToolChoice +GenerateActionOptions.config type any +GenerateActionOptions.output type *GenerateActionOptionsOutput +GenerateActionOptions.returnToolRequests type bool +GenerateActionOptions.maxTurns type int + +GenerateActionOptionsOutput.instructions omit +GenerateActionOptionsOutput.format type OutputFormat +GenerateActionOptionsOutput.jsonSchema type map[string]any + # ModelRequest ModelRequest pkg ai ModelRequest.config type any @@ -184,6 +209,7 @@ ModelRequest.context type []any ModelRequest.messages type []*Message ModelRequest.output type *ModelRequestOutput ModelRequest.tools type []*ToolDefinition +ModelRequest.toolChoice type ToolChoice # ModelResponse ModelResponse pkg ai @@ -201,6 +227,8 @@ ModelResponseChunk pkg ai 
ModelResponseChunk.aggregated type bool ModelResponseChunk.content type []*Part ModelResponseChunk.custom type any +ModelResponseChunk.index type int +ModelResponseChunk.role type Role GenerationCommonConfig doc GenerationCommonConfig holds configuration for generation. diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index dc0e8aac8..73e7c9241 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -86,6 +86,8 @@ func New(opts *Options) (*Genkit, error) { // Thus Start(nil) will start a dev server in the "dev" environment, will always start // a flow server, and will pause execution until the flow server terminates. func (g *Genkit) Start(ctx context.Context, opts *StartOptions) error { + ai.DefineGenerateAction(ctx, g.reg) + if opts == nil { opts = &StartOptions{} } @@ -174,7 +176,7 @@ func LookupModel(g *Genkit, provider, name string) ai.Model { } // DefineTool defines a tool to be passed to a model generate call. -func DefineTool[In, Out any](g *Genkit, name, description string, fn func(ctx context.Context, input In) (Out, error)) *ai.ToolDef[In, Out] { +func DefineTool[In, Out any](g *Genkit, name, description string, fn func(ctx *ai.ToolContext, input In) (Out, error)) *ai.ToolDef[In, Out] { return ai.DefineTool(g.reg, name, description, fn) } @@ -237,8 +239,8 @@ func GenerateData(ctx context.Context, g *Genkit, value any, opts ...ai.Generate } // GenerateWithRequest runs the model with the given request and streaming callback. 
-func GenerateWithRequest(ctx context.Context, g *Genkit, m ai.Model, req *ai.ModelRequest, cb ai.ModelStreamingCallback) (*ai.ModelResponse, error) { - return m.Generate(ctx, g.reg, req, cb) +func GenerateWithRequest(ctx context.Context, g *Genkit, m ai.Model, req *ai.ModelRequest, toolCfg *ai.ToolConfig, cb ai.ModelStreamingCallback) (*ai.ModelResponse, error) { + return m.Generate(ctx, g.reg, req, toolCfg, cb) } // DefineIndexer registers the given index function as an action, and returns an diff --git a/go/go.mod b/go/go.mod index 502f4c3de..9e1fb4f0e 100644 --- a/go/go.mod +++ b/go/go.mod @@ -3,8 +3,8 @@ module github.com/firebase/genkit/go go 1.22.0 retract ( - v0.1.3 // This shold have been a minor release. v0.1.4 // Retraction only. + v0.1.3 // This shold have been a minor release. ) require ( diff --git a/go/internal/atype/atype.go b/go/internal/atype/atype.go index da4f54d50..baeae4e9b 100644 --- a/go/internal/atype/atype.go +++ b/go/internal/atype/atype.go @@ -19,5 +19,6 @@ const ( Model ActionType = "model" Prompt ActionType = "prompt" Tool ActionType = "tool" + Util ActionType = "util" Custom ActionType = "custom" ) diff --git a/go/internal/base/validation.go b/go/internal/base/validation.go index cd24bedc8..aa4141735 100644 --- a/go/internal/base/validation.go +++ b/go/internal/base/validation.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package base import ( @@ -16,6 +15,9 @@ import ( // ValidateValue will validate any value against the expected schema. // It will return an error if it doesn't match the schema, otherwise it will return nil. func ValidateValue(data any, schema *jsonschema.Schema) error { + if schema == nil { + return nil + } dataBytes, err := json.Marshal(data) if err != nil { return fmt.Errorf("data is not a valid JSON type: %w", err) @@ -26,6 +28,9 @@ func ValidateValue(data any, schema *jsonschema.Schema) error { // ValidateJSON will validate JSON against the expected schema. 
// It will return an error if it doesn't match the schema, otherwise it will return nil. func ValidateJSON(dataBytes json.RawMessage, schema *jsonschema.Schema) error { + if schema == nil { + return nil + } schemaBytes, err := schema.MarshalJSON() if err != nil { return fmt.Errorf("expected schema is not valid: %w", err) diff --git a/go/internal/doc-snippets/models.go b/go/internal/doc-snippets/models.go index 6639afa17..219765b30 100644 --- a/go/internal/doc-snippets/models.go +++ b/go/internal/doc-snippets/models.go @@ -140,7 +140,7 @@ func tools() error { g, "myJoke", "useful when you need a joke to tell", - func(ctx context.Context, input *any) (string, error) { + func(ctx *ai.ToolContext, input *any) (string, error) { return "haha Just kidding no joke! got you", nil }, ) diff --git a/go/internal/doc-snippets/prompts.go b/go/internal/doc-snippets/prompts.go index 7944f977c..d39636599 100644 --- a/go/internal/doc-snippets/prompts.go +++ b/go/internal/doc-snippets/prompts.go @@ -95,7 +95,7 @@ func pr03() error { if err != nil { return err } - response, err := genkit.GenerateWithRequest(context.Background(), g, model, request, nil) + response, err := genkit.GenerateWithRequest(context.Background(), g, model, request, nil, nil) // [END pr03_2] _ = response diff --git a/go/plugins/dotprompt/dotprompt.go b/go/plugins/dotprompt/dotprompt.go index f7850f4e2..243e7a701 100644 --- a/go/plugins/dotprompt/dotprompt.go +++ b/go/plugins/dotprompt/dotprompt.go @@ -44,18 +44,13 @@ type Prompt struct { // The name of the prompt. Optional unless the prompt is // registered as an action. Name string - Config - // The parsed prompt template. Template *raymond.Template - // The original prompt template text. TemplateText string - // A hash of the prompt contents. hash string - // A prompt that renders the prompt. prompt *ai.Prompt } @@ -67,31 +62,29 @@ type Config struct { // The name of the model for which the prompt is input. // If this is non-empty, Model should be nil. 
ModelName string - // The Model to use. // If this is set, ModelName should be an empty string. Model ai.Model - // TODO: document Tools []ai.Tool - // Details for the model. GenerationConfig *ai.GenerationCommonConfig - // Schema for input variables. InputSchema *jsonschema.Schema - // Default input variable values DefaultInput map[string]any - // Desired output format. OutputFormat ai.OutputFormat - // Desired output schema, for JSON output. OutputSchema *jsonschema.Schema - // Arbitrary metadata. Metadata map[string]any + // ToolChoice is the tool choice to use. + ToolChoice ai.ToolChoice + // MaxTurns is the maximum number of turns. + MaxTurns int + // ReturnToolRequests is whether to return tool requests. + ReturnToolRequests bool } // PromptOption configures params for the prompt @@ -153,7 +146,10 @@ type frontmatterYAML struct { Format string `yaml:"format,omitempty"` Schema any `yaml:"schema,omitempty"` } `yaml:"output,omitempty"` - Metadata map[string]any `yaml:"metadata,omitempty"` + Metadata map[string]any `yaml:"metadata,omitempty"` + ToolChoice string `yaml:"toolChoice,omitempty"` + MaxTurns int `yaml:"maxTurns,omitempty"` + ReturnToolRequests bool `yaml:"returnToolRequests,omitempty"` } // Parse parses the contents of a dotprompt file. 
@@ -214,12 +210,15 @@ func parseFrontmatter(g *genkit.Genkit, data []byte) (name string, c Config, res } ret := Config{ - Variant: fy.Variant, - ModelName: fy.Model, - Tools: tools, - GenerationConfig: fy.Config, - DefaultInput: fy.Input.Default, - Metadata: fy.Metadata, + Variant: fy.Variant, + ModelName: fy.Model, + Tools: tools, + GenerationConfig: fy.Config, + DefaultInput: fy.Input.Default, + Metadata: fy.Metadata, + ToolChoice: ai.ToolChoice(fy.ToolChoice), + MaxTurns: fy.MaxTurns, + ReturnToolRequests: fy.ReturnToolRequests, } inputSchema, err := picoschemaToJSONSchema(fy.Input.Schema) @@ -315,7 +314,7 @@ func sortSchemaSlices(s *jsonschema.Schema) { func WithTools(tools ...ai.Tool) PromptOption { return func(p *Prompt) error { if p.Config.Tools != nil { - return errors.New("dotprompt.WithTools: cannot set tools more than once") + return errors.New("dotprompt.WithTools: cannot set Tools more than once") } var toolSlice []ai.Tool @@ -439,3 +438,39 @@ func WithDefaultModelName(name string) PromptOption { return nil } } + +// WithDefaultMaxTurns sets the default maximum number of tool call iterations for the prompt. +func WithDefaultMaxTurns(maxTurns int) PromptOption { + return func(p *Prompt) error { + if maxTurns <= 0 { + return fmt.Errorf("maxTurns must be greater than 0, got %d", maxTurns) + } + if p.Config.MaxTurns != 0 { + return errors.New("dotprompt.WithMaxTurns: cannot set MaxTurns more than once") + } + p.Config.MaxTurns = maxTurns + return nil + } +} + +// WithDefaultReturnToolRequests configures whether by default to return tool requests instead of making the tool calls and continuing the generation. 
+func WithDefaultReturnToolRequests(returnToolRequests bool) PromptOption { + return func(p *Prompt) error { + if p.Config.ReturnToolRequests { + return errors.New("dotprompt.WithReturnToolRequests: cannot set ReturnToolRequests more than once") + } + p.Config.ReturnToolRequests = returnToolRequests + return nil + } +} + +// WithDefaultToolChoice configures whether by default tool calls are required, disabled, or optional for the prompt. +func WithDefaultToolChoice(toolChoice ai.ToolChoice) PromptOption { + return func(p *Prompt) error { + if p.Config.ToolChoice != "" { + return errors.New("dotprompt.WithToolChoice: cannot set ToolChoice more than once") + } + p.Config.ToolChoice = toolChoice + return nil + } +} diff --git a/go/plugins/dotprompt/dotprompt_test.go b/go/plugins/dotprompt/dotprompt_test.go index ee4590e4b..44a45d8a5 100644 --- a/go/plugins/dotprompt/dotprompt_test.go +++ b/go/plugins/dotprompt/dotprompt_test.go @@ -5,7 +5,6 @@ package dotprompt import ( - "context" "encoding/json" "log" "testing" @@ -22,7 +21,7 @@ type InputOutput struct { func testTool(g *genkit.Genkit, name string) *ai.ToolDef[struct{ Test string }, string] { return genkit.DefineTool(g, name, "use when need to execute a test", - func(ctx context.Context, input struct { + func(ctx *ai.ToolContext, input struct { Test string }) (string, error) { return input.Test, nil diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index 0184a1e8e..eec20b4ba 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -32,6 +32,14 @@ type PromptRequest struct { ModelName string `json:"modelname,omitempty"` // Streaming callback function Stream ai.ModelStreamingCallback + // Maximum number of tool call iterations for the prompt. + MaxTurns int `json:"maxTurns,omitempty"` + // Whether to return tool requests instead of making the tool calls and continuing the generation. 
+ ReturnToolRequests bool `json:"returnToolRequests,omitempty"` + // Whether the ReturnToolRequests field was set (false is not enough information as to whether to override). + IsReturnToolRequestsSet bool `json:"-"` + // Whether tool calls are required, disabled, or optional for the prompt. + ToolChoice ai.ToolChoice `json:"toolChoice,omitempty"` } // GenerateOption configures params for Generate function @@ -106,6 +114,7 @@ func (p *Prompt) buildRequest(ctx context.Context, input any) (*ai.ModelRequest, } req.Config = p.GenerationConfig + req.ToolChoice = p.ToolChoice var outputSchema map[string]any if p.OutputSchema != nil { @@ -203,6 +212,9 @@ func (p *Prompt) Generate(ctx context.Context, g *genkit.Genkit, opts ...Generat if len(pr.Context) > 0 { mr.Context = pr.Context } + if pr.ToolChoice != "" { + mr.ToolChoice = pr.ToolChoice + } // Setting the model on generate, overrides the model defined on the prompt var model ai.Model @@ -232,7 +244,22 @@ func (p *Prompt) Generate(ctx context.Context, g *genkit.Genkit, opts ...Generat } } - resp, err := genkit.GenerateWithRequest(ctx, g, model, mr, pr.Stream) + maxTurns := p.Config.MaxTurns + if pr.MaxTurns != 0 { + maxTurns = pr.MaxTurns + } + + returnToolRequests := p.Config.ReturnToolRequests + if pr.IsReturnToolRequestsSet { + returnToolRequests = pr.ReturnToolRequests + } + + toolCfg := &ai.ToolConfig{ + MaxTurns: maxTurns, + ReturnToolRequests: returnToolRequests, + } + + resp, err := genkit.GenerateWithRequest(ctx, g, model, mr, toolCfg, pr.Stream) if err != nil { return nil, err } @@ -327,11 +354,48 @@ func WithModelName(model string) GenerateOption { // WithStreaming adds a streaming callback to the generate request. 
func WithStreaming(cb ai.ModelStreamingCallback) GenerateOption { - return func(g *PromptRequest) error { - if g.Stream != nil { + return func(p *PromptRequest) error { + if p.Stream != nil { return errors.New("dotprompt.WithStreaming: cannot set Stream more than once") } - g.Stream = cb + p.Stream = cb + return nil + } +} + +// WithMaxTurns sets the maximum number of tool call iterations for the prompt. +func WithMaxTurns(maxTurns int) GenerateOption { + return func(p *PromptRequest) error { + if maxTurns <= 0 { + return fmt.Errorf("maxTurns must be greater than 0, got %d", maxTurns) + } + if p.MaxTurns != 0 { + return errors.New("dotprompt.WithMaxTurns: cannot set MaxTurns more than once") + } + p.MaxTurns = maxTurns + return nil + } +} + +// WithReturnToolRequests configures whether to return tool requests instead of making the tool calls and continuing the generation. +func WithReturnToolRequests(returnToolRequests bool) GenerateOption { + return func(p *PromptRequest) error { + if p.IsReturnToolRequestsSet { + return errors.New("dotprompt.WithReturnToolRequests: cannot set ReturnToolRequests more than once") + } + p.ReturnToolRequests = returnToolRequests + p.IsReturnToolRequestsSet = true + return nil + } +} + +// WithToolChoice configures whether tool calls are required, disabled, or optional for the prompt. 
+func WithToolChoice(toolChoice ai.ToolChoice) GenerateOption { + return func(p *PromptRequest) error { + if p.ToolChoice != "" { + return errors.New("dotprompt.WithToolChoice: cannot set ToolChoice more than once") + } + p.ToolChoice = toolChoice return nil } } diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index 1a6080bca..aed57b86f 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -570,16 +570,33 @@ func convertPart(p *ai.Part) (genai.Part, error) { panic(fmt.Sprintf("%s does not support Data parts", provider)) case p.IsToolResponse(): toolResp := p.ToolResponse + var output map[string]any + if m, ok := toolResp.Output.(map[string]any); ok { + output = m + } else { + output = map[string]any{ + "name": toolResp.Name, + "content": toolResp.Output, + } + } fr := genai.FunctionResponse{ Name: toolResp.Name, - Response: toolResp.Output, + Response: output, } return fr, nil case p.IsToolRequest(): toolReq := p.ToolRequest + var input map[string]any + if m, ok := toolReq.Input.(map[string]any); ok { + input = m + } else { + input = map[string]any{ + "input": toolReq.Input, + } + } fc := genai.FunctionCall{ Name: toolReq.Name, - Args: toolReq.Input, + Args: input, } return fc, nil default: diff --git a/go/plugins/googleai/googleai_test.go b/go/plugins/googleai/googleai_test.go index 59199ed3f..b24089169 100644 --- a/go/plugins/googleai/googleai_test.go +++ b/go/plugins/googleai/googleai_test.go @@ -55,7 +55,7 @@ func TestLive(t *testing.T) { t.Fatal(err) } gablorkenTool := genkit.DefineTool(g, "gablorken", "use when need to calculate a gablorken", - func(ctx context.Context, input struct { + func(ctx *ai.ToolContext, input struct { Value float64 Over float64 }) (float64, error) { diff --git a/go/plugins/ollama/ollama_live_test.go b/go/plugins/ollama/ollama_live_test.go index 4b4bb5c9d..a7e11fd50 100644 --- a/go/plugins/ollama/ollama_live_test.go +++ b/go/plugins/ollama/ollama_live_test.go @@ -57,7 +57,7 
@@ func TestLive(t *testing.T) { ai.NewModelRequest( &ai.GenerationCommonConfig{Temperature: 1}, ai.NewUserTextMessage("I'm hungry, what should I eat?")), - nil) + nil, nil) if err != nil { t.Fatalf("failed to generate response: %s", err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 186b4a7bd..15580566a 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -587,16 +587,33 @@ func convertPart(p *ai.Part) (genai.Part, error) { panic(fmt.Sprintf("%s does not support Data parts", provider)) case p.IsToolResponse(): toolResp := p.ToolResponse + var output map[string]any + if m, ok := toolResp.Output.(map[string]any); ok { + output = m + } else { + output = map[string]any{ + "name": toolResp.Name, + "content": toolResp.Output, + } + } fr := genai.FunctionResponse{ Name: toolResp.Name, - Response: toolResp.Output, + Response: output, } return fr, nil case p.IsToolRequest(): toolReq := p.ToolRequest + var input map[string]any + if m, ok := toolReq.Input.(map[string]any); ok { + input = m + } else { + input = map[string]any{ + "input": toolReq.Input, + } + } fc := genai.FunctionCall{ Name: toolReq.Name, - Args: toolReq.Input, + Args: input, } return fc, nil default: diff --git a/go/plugins/vertexai/vertexai_test.go b/go/plugins/vertexai/vertexai_test.go index bd9de850a..4610b192a 100644 --- a/go/plugins/vertexai/vertexai_test.go +++ b/go/plugins/vertexai/vertexai_test.go @@ -40,7 +40,7 @@ func TestLive(t *testing.T) { embedder := vertexai.Embedder(g, "textembedding-gecko@003") gablorkenTool := genkit.DefineTool(g, "gablorken", "use when need to calculate a gablorken", - func(ctx context.Context, input struct { + func(ctx *ai.ToolContext, input struct { Value float64 Over float64 }) (float64, error) { diff --git a/go/samples/menu/s02.go b/go/samples/menu/s02.go index 6dc93eaa4..caba647b3 100644 --- a/go/samples/menu/s02.go +++ b/go/samples/menu/s02.go @@ -13,7 +13,7 @@ import ( 
"github.com/firebase/genkit/go/plugins/dotprompt" ) -func menu(ctx context.Context, _ *any) ([]*menuItem, error) { +func menu(ctx *ai.ToolContext, _ *any) ([]*menuItem, error) { f, err := os.Open("testdata/menu.json") if err != nil { return nil, err diff --git a/go/samples/menu/s03.go b/go/samples/menu/s03.go index 837807c47..66e02ae91 100644 --- a/go/samples/menu/s03.go +++ b/go/samples/menu/s03.go @@ -71,7 +71,7 @@ func setup03(g *genkit.Genkit, m ai.Model) error { return err } - menuData, err := menu(context.Background(), nil) + menuData, err := menu(&ai.ToolContext{Context: context.Background()}, nil) if err != nil { return err }