Skip to content

fix(go): fix bug where dev UI configs aren't respected #2234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ type ModelInfo struct {
Label string `json:"label,omitempty"`
Supports *ModelInfoSupports `json:"supports,omitempty"`
Versions []string `json:"versions,omitempty"`
ConfigSchema map[string]any `json:"configSchema,omitempty"`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to come from genkit-tools otherwise it will get overwritten on next generation.

}

type ModelInfoSupports struct {
Expand Down
5 changes: 3 additions & 2 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ func DefineModel(
}
metadataMap["supports"] = supports
metadataMap["versions"] = info.Versions

if info.ConfigSchema != nil {
metadataMap["customOptions"] = info.ConfigSchema
}
generate = core.ChainMiddleware(ValidateSupport(name, info))(generate)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, generate))
}

Expand Down
193 changes: 172 additions & 21 deletions go/plugins/googleai/googleai.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package googleai

import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
Expand All @@ -17,9 +18,11 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/genkit"
"github.com/firebase/genkit/go/internal"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/plugins/internal/gemini"
"github.com/firebase/genkit/go/plugins/internal/uri"
"github.com/google/generative-ai-go/genai"
"github.com/invopop/jsonschema"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)
Expand All @@ -36,6 +39,57 @@ var state struct {
initted bool
}

type HarmCategory int32

const (
// HarmCategoryUnspecified means category is unspecified.
HarmCategoryUnspecified HarmCategory = 0
// HarmCategoryDerogatory means negative or harmful comments targeting identity and/or protected attribute.
HarmCategoryDerogatory HarmCategory = 1
// HarmCategoryToxicity means content that is rude, disrespectful, or profane.
HarmCategoryToxicity HarmCategory = 2
// HarmCategoryViolence means describes scenarios depicting violence against an individual or group, or
// general descriptions of gore.
HarmCategoryViolence HarmCategory = 3
// HarmCategorySexual means contains references to sexual acts or other lewd content.
HarmCategorySexual HarmCategory = 4
// HarmCategoryMedical means promotes unchecked medical advice.
HarmCategoryMedical HarmCategory = 5
// HarmCategoryDangerous means dangerous content that promotes, facilitates, or encourages harmful acts.
HarmCategoryDangerous HarmCategory = 6
// HarmCategoryHarassment means harasment content.
HarmCategoryHarassment HarmCategory = 7
// HarmCategoryHateSpeech means hate speech and content.
HarmCategoryHateSpeech HarmCategory = 8
// HarmCategorySexuallyExplicit means sexually explicit content.
HarmCategorySexuallyExplicit HarmCategory = 9
// HarmCategoryDangerousContent means dangerous content.
HarmCategoryDangerousContent HarmCategory = 10
)

// HarmBlockThreshold specifies block at and beyond a specified harm probability.
type HarmBlockThreshold int32

const (
// HarmBlockUnspecified means threshold is unspecified.
HarmBlockUnspecified HarmBlockThreshold = 0
// HarmBlockLowAndAbove means content with NEGLIGIBLE will be allowed.
HarmBlockLowAndAbove HarmBlockThreshold = 1
// HarmBlockMediumAndAbove means content with NEGLIGIBLE and LOW will be allowed.
HarmBlockMediumAndAbove HarmBlockThreshold = 2
// HarmBlockOnlyHigh means content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
HarmBlockOnlyHigh HarmBlockThreshold = 3
// HarmBlockNone means all content will be allowed.
HarmBlockNone HarmBlockThreshold = 4
)

type SafetySetting struct {
// Required. The category for this setting.
Category HarmCategory
// Required. Controls the probability threshold at which harm is blocked.
Threshold HarmBlockThreshold
}

var (
supportedModels = map[string]ai.ModelInfo{
"gemini-1.5-flash": {
Expand Down Expand Up @@ -89,6 +143,12 @@ var (
}
)

// GenerationGoogleAIConfig extends GenerationCommonConfig with Google AI specific settings.
type GenerationGoogleAIConfig struct {
ai.GenerationCommonConfig
SafetySettings []*SafetySetting
}

// Config is the configuration for the plugin.
type Config struct {
// The API key to access the service.
Expand Down Expand Up @@ -179,9 +239,10 @@ func DefineModel(g *genkit.Genkit, name string, info *ai.ModelInfo) (ai.Model, e
// requires state.mu
func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model {
meta := &ai.ModelInfo{
Label: labelPrefix + " - " + name,
Supports: info.Supports,
Versions: info.Versions,
Label: labelPrefix + " - " + name,
Supports: info.Supports,
Versions: info.Versions,
ConfigSchema: convertConfigSchemaToMap(&GenerationGoogleAIConfig{}),
}
return genkit.DefineModel(g, provider, name, meta, func(
ctx context.Context,
Expand All @@ -192,6 +253,16 @@ func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model {
})
}

func convertConfigSchemaToMap(config any) map[string]any {
r := jsonschema.Reflector{
DoNotReference: true, // Prevent $ref usage
ExpandedStruct: true, // Include all fields directly
}
schema := r.Reflect(config)
result := base.SchemaAsMap(schema)
return result
}

// IsDefinedModel reports whether the named [Model] is defined by this plugin.
func IsDefinedModel(g *genkit.Genkit, name string) bool {
return genkit.IsDefinedModel(g, provider, name)
Expand Down Expand Up @@ -338,26 +409,90 @@ func generate(
return r, nil
}

func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) {
gm := client.GenerativeModel(model)
gm.SetCandidateCount(1)
if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil {
if c.MaxOutputTokens != 0 {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
}
if len(c.StopSequences) > 0 {
gm.StopSequences = c.StopSequences
}
if c.Temperature != 0 {
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
func mapToStruct(m map[string]any, v any) error {
jsonData, err := json.Marshal(m)
if err != nil {
return err
}
return json.Unmarshal(jsonData, v)
}

// applyGenerationConfig applies the common generation configuration to the model
func applyGenerationConfig(gm *genai.GenerativeModel, c GenerationGoogleAIConfig) {
if c.MaxOutputTokens != 0 {
gm.SetMaxOutputTokens(int32(c.MaxOutputTokens))
}
if len(c.StopSequences) > 0 {
gm.StopSequences = c.StopSequences
}
if c.Temperature != 0 {
gm.SetTemperature(float32(c.Temperature))
}
if c.TopK != 0 {
gm.SetTopK(int32(c.TopK))
}
if c.TopP != 0 {
gm.SetTopP(float32(c.TopP))
}
if len(c.SafetySettings) > 0 {
gm.SafetySettings = convertSafetySettings(c.SafetySettings)
}
}

// extractConfigFromInput converts any supported config type to GoogleAIConfig
func extractConfigFromInput(input *ai.ModelRequest) (GenerationGoogleAIConfig, error) {
var result GenerationGoogleAIConfig
switch config := input.Config.(type) {
case GenerationGoogleAIConfig:
return config, nil
case *GenerationGoogleAIConfig:
return *config, nil
case ai.GenerationCommonConfig:
result.MaxOutputTokens = config.MaxOutputTokens
result.StopSequences = config.StopSequences
result.Temperature = config.Temperature
result.TopK = config.TopK
result.TopP = config.TopP
result.Version = config.Version
return result, nil
case *ai.GenerationCommonConfig:
if config != nil {
result.MaxOutputTokens = config.MaxOutputTokens
result.StopSequences = config.StopSequences
result.Temperature = config.Temperature
result.TopK = config.TopK
result.TopP = config.TopP
result.Version = config.Version
}
return result, nil
case map[string]any:
// Todo: this will silently fail if extra parameters are passed, may want to expose errors
if err := mapToStruct(config, &result); err == nil {
return result, nil
} else {
return result, err
}
case nil:
// Empty but valid config
return result, nil
default:
return result, fmt.Errorf("unexpected config type: %T", input.Config)
}
}

func newModel(client *genai.Client, model string, input *ai.ModelRequest) (*genai.GenerativeModel, error) {
c, err := extractConfigFromInput(input)
if err != nil {
return nil, err
}

specifiedModel := model
if c.Version != "" {
specifiedModel = c.Version
}
gm := client.GenerativeModel(specifiedModel)
gm.SetCandidateCount(1)
applyGenerationConfig(gm, c)
for _, m := range input.Messages {
systemParts, err := convertParts(m.Content)
if err != nil {
Expand Down Expand Up @@ -658,3 +793,19 @@ func convertPart(p *ai.Part) (genai.Part, error) {
}

//copy:stop

// convertSafetySettings converts local SafetySetting to genai.SafetySetting
func convertSafetySettings(settings []*SafetySetting) []*genai.SafetySetting {
if len(settings) == 0 {
return nil
}

result := make([]*genai.SafetySetting, len(settings))
for i, s := range settings {
result[i] = &genai.SafetySetting{
Category: genai.HarmCategory(s.Category),
Threshold: genai.HarmBlockThreshold(s.Threshold),
}
}
return result
}
15 changes: 12 additions & 3 deletions go/samples/basic-gemini/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,18 @@ func main() {

resp, err := genkit.Generate(ctx, g,
ai.WithModel(m),
ai.WithConfig(&ai.GenerationCommonConfig{
Temperature: 1,
Version: "gemini-2.0-flash-001",
ai.WithConfig(&googleai.GenerationGoogleAIConfig{
GenerationCommonConfig: ai.GenerationCommonConfig{
Temperature: 1.0,
MaxOutputTokens: 256,
},
// Set custom safety settings - reduce restriction on harmfulness
SafetySettings: []*googleai.SafetySetting{
{
Category: googleai.HarmCategoryHarassment,
Threshold: googleai.HarmBlockMediumAndAbove,
},
},
}),
ai.WithTextPrompt(fmt.Sprintf(`Tell silly short jokes about %s`, input)))
if err != nil {
Expand Down
Loading