Skip to content

Commit

Permalink
feat(go): Added /util/generate + filled feature gaps in Generate AP…
Browse files Browse the repository at this point in the history
…I. (#1818)
  • Loading branch information
apascal07 authored Feb 8, 2025
1 parent 53c5ad2 commit 8dc9ed8
Show file tree
Hide file tree
Showing 26 changed files with 1,067 additions and 193 deletions.
2 changes: 2 additions & 0 deletions genkit-tools/common/src/types/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ export const ModelInfoSchema = z.object({
context: z.boolean().optional(),
/** Model can natively support constrained generation. */
constrained: z.enum(['none', 'all', 'no-tools']).optional(),
/** Model supports controlling tool choice, e.g. forced tool calling. */
toolChoice: z.boolean().optional(),
})
.optional(),
});
Expand Down
3 changes: 3 additions & 0 deletions genkit-tools/genkit-schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,9 @@
"all",
"no-tools"
]
},
"toolChoice": {
"type": "boolean"
}
},
"additionalProperties": false
Expand Down
145 changes: 145 additions & 0 deletions go/ai/action_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"
"os"
"testing"

"github.com/firebase/genkit/go/internal/registry"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gopkg.in/yaml.v3"
)

type specSuite struct {
Tests []testCase `yaml:"tests"`
}

type testCase struct {
Name string `yaml:"name"`
Input *GenerateActionOptions `yaml:"input"`
StreamChunks [][]*ModelResponseChunk `yaml:"streamChunks,omitempty"`
ModelResponses []*ModelResponse `yaml:"modelResponses"`
ExpectResponse *ModelResponse `yaml:"expectResponse,omitempty"`
Stream bool `yaml:"stream,omitempty"`
ExpectChunks []*ModelResponseChunk `yaml:"expectChunks,omitempty"`
}

type programmableModel struct {
r *registry.Registry
handleResp func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error)
lastRequest *ModelRequest
}

func (pm *programmableModel) Name() string {
return "programmableModel"
}

func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) {
pm.lastRequest = req
return pm.handleResp(ctx, req, cb)
}

func defineProgrammableModel(r *registry.Registry) *programmableModel {
pm := &programmableModel{r: r}
DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
}

func TestGenerateAction(t *testing.T) {
data, err := os.ReadFile("../../tests/specs/generate.yaml")
if err != nil {
t.Fatalf("failed to read spec file: %v", err)
}

var suite specSuite
if err := yaml.Unmarshal(data, &suite); err != nil {
t.Fatalf("failed to parse spec file: %v", err)
}

for _, tc := range suite.Tests {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

r, err := registry.New()
if err != nil {
t.Fatalf("failed to create registry: %v", err)
}

pm := defineProgrammableModel(r)

DefineTool(r, "testTool", "description",
func(ctx *ToolContext, input any) (any, error) {
return "tool called", nil
})

if len(tc.ModelResponses) > 0 || len(tc.StreamChunks) > 0 {
reqCounter := 0
pm.handleResp = func(ctx context.Context, req *ModelRequest, cb func(context.Context, *ModelResponseChunk) error) (*ModelResponse, error) {
if len(tc.StreamChunks) > 0 && cb != nil {
for _, chunk := range tc.StreamChunks[reqCounter] {
if err := cb(ctx, chunk); err != nil {
return nil, err
}
}
}
resp := tc.ModelResponses[reqCounter]
resp.Request = req
resp.Custom = map[string]any{}
resp.Request.Output = &ModelRequestOutput{}
resp.Usage = &GenerationUsage{}
reqCounter++
return resp, nil
}
}

genAction := DefineGenerateAction(ctx, r)

if tc.Stream {
chunks := []*ModelResponseChunk{}
streamCb := func(ctx context.Context, chunk *ModelResponseChunk) error {
chunks = append(chunks, chunk)
return nil
}

resp, err := genAction.Run(ctx, tc.Input, streamCb)
if err != nil {
t.Fatalf("action failed: %v", err)
}

if diff := cmp.Diff(tc.ExpectChunks, chunks); diff != "" {
t.Errorf("chunks mismatch (-want +got):\n%s", diff)
}

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
} else {
resp, err := genAction.Run(ctx, tc.Input, nil)
if err != nil {
t.Fatalf("action failed: %v", err)
}

if diff := cmp.Diff(tc.ExpectResponse, resp, cmp.Options{cmpopts.EquateEmpty()}); diff != "" {
t.Errorf("response mismatch (-want +got):\n%s", diff)
}
}
})
}
}
89 changes: 51 additions & 38 deletions go/ai/document.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
"encoding/json"
"fmt"

"gopkg.in/yaml.v3"
)

// A Document is a piece of data that can be embedded, indexed, or retrieved.
Expand All @@ -21,11 +22,12 @@ type Document struct {
// A Part is one part of a [Document]. This may be plain text or it
// may be a URL (possibly a "data:" URL with embedded data).
type Part struct {
Kind PartKind `json:"kind,omitempty"`
ContentType string `json:"contentType,omitempty"` // valid for kind==blob
Text string `json:"text,omitempty"` // valid for kind∈{text,blob}
ToolRequest *ToolRequest `json:"toolreq,omitempty"` // valid for kind==partToolRequest
ToolResponse *ToolResponse `json:"toolresp,omitempty"` // valid for kind==partToolResponse
Kind PartKind `json:"kind,omitempty"`
ContentType string `json:"contentType,omitempty"` // valid for kind==blob
Text string `json:"text,omitempty"` // valid for kind∈{text,blob}
ToolRequest *ToolRequest `json:"toolRequest,omitempty"` // valid for kind==partToolRequest
ToolResponse *ToolResponse `json:"toolResponse,omitempty"` // valid for kind==partToolResponse
Metadata map[string]any `json:"metadata,omitempty"` // valid for all kinds
}

type PartKind int8
Expand Down Expand Up @@ -105,7 +107,8 @@ func (p *Part) MarshalJSON() ([]byte, error) {
switch p.Kind {
case PartText:
v := textPart{
Text: p.Text,
Text: p.Text,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartMedia:
Expand All @@ -114,28 +117,25 @@ func (p *Part) MarshalJSON() ([]byte, error) {
ContentType: p.ContentType,
Url: p.Text,
},
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartData:
v := dataPart{
Data: p.Text,
Data: p.Text,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartToolRequest:
// TODO: make sure these types marshal/unmarshal nicely
// between Go and javascript. At the very least the
// field name needs to change (here and in UnmarshalJSON).
v := struct {
ToolReq *ToolRequest `json:"toolreq,omitempty"`
}{
ToolReq: p.ToolRequest,
v := toolRequestPart{
ToolRequest: p.ToolRequest,
Metadata: p.Metadata,
}
return json.Marshal(v)
case PartToolResponse:
v := struct {
ToolResp *ToolResponse `json:"toolresp,omitempty"`
}{
ToolResp: p.ToolResponse,
v := toolResponsePart{
ToolResponse: p.ToolResponse,
Metadata: p.Metadata,
}
return json.Marshal(v)
default:
Expand All @@ -144,34 +144,27 @@ func (p *Part) MarshalJSON() ([]byte, error) {
}

type partSchema struct {
Text string `json:"text,omitempty"`
Media *mediaPartMedia `json:"media,omitempty"`
Data string `json:"data,omitempty"`
ToolReq *ToolRequest `json:"toolreq,omitempty"`
ToolResp *ToolResponse `json:"toolresp,omitempty"`
Text string `json:"text,omitempty" yaml:"text,omitempty"`
Media *mediaPartMedia `json:"media,omitempty" yaml:"media,omitempty"`
Data string `json:"data,omitempty" yaml:"data,omitempty"`
ToolRequest *ToolRequest `json:"toolRequest,omitempty" yaml:"toolRequest,omitempty"`
ToolResponse *ToolResponse `json:"toolResponse,omitempty" yaml:"toolResponse,omitempty"`
Metadata map[string]any `json:"metadata,omitempty" yaml:"metadata,omitempty"`
}

// UnmarshalJSON is called by the JSON unmarshaler to read a Part.
func (p *Part) UnmarshalJSON(b []byte) error {
// This is not handled by the schema generator because
// Part is defined in TypeScript as a union.

var s partSchema
if err := json.Unmarshal(b, &s); err != nil {
return err
}

// unmarshalPartFromSchema updates Part p based on the schema s.
func (p *Part) unmarshalPartFromSchema(s partSchema) {
switch {
case s.Media != nil:
p.Kind = PartMedia
p.Text = s.Media.Url
p.ContentType = s.Media.ContentType
case s.ToolReq != nil:
case s.ToolRequest != nil:
p.Kind = PartToolRequest
p.ToolRequest = s.ToolReq
case s.ToolResp != nil:
p.ToolRequest = s.ToolRequest
case s.ToolResponse != nil:
p.Kind = PartToolResponse
p.ToolResponse = s.ToolResp
p.ToolResponse = s.ToolResponse
default:
p.Kind = PartText
p.Text = s.Text
Expand All @@ -182,6 +175,26 @@ func (p *Part) UnmarshalJSON(b []byte) error {
p.Text = s.Data
}
}
p.Metadata = s.Metadata
}

// UnmarshalJSON is called by the JSON unmarshaler to read a Part.
func (p *Part) UnmarshalJSON(b []byte) error {
var s partSchema
if err := json.Unmarshal(b, &s); err != nil {
return err
}
p.unmarshalPartFromSchema(s)
return nil
}

// UnmarshalYAML implements yaml.Unmarshaler for Part.
func (p *Part) UnmarshalYAML(value *yaml.Node) error {
var s partSchema
if err := value.Decode(&s); err != nil {
return err
}
p.unmarshalPartFromSchema(s)
return nil
}

Expand Down
Loading

0 comments on commit 8dc9ed8

Please sign in to comment.