From 3291eac67eb3e2dc5e87943bae7c56744573b6f0 Mon Sep 17 00:00:00 2001 From: Andreas Humenberger Date: Wed, 19 Feb 2025 15:20:06 +0100 Subject: [PATCH] fix, Allow selecting models with attributes --- cmd/eval-dev-quality/cmd/evaluate.go | 21 +++++++++---- cmd/eval-dev-quality/cmd/evaluate_test.go | 38 +++++++++++++++++++++++ model/llm/llm.go | 19 +++++++++++- model/model.go | 5 +++ model/symflower/symflower.go | 11 +++++++ model/testing/Model_mock_gen.go | 25 +++++++++++++++ 6 files changed, 112 insertions(+), 7 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index fa756f9e..b3d31765 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -447,7 +447,8 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", modelIDsWithProviderAndAttributes, provider.ProviderModelSeparator) } - modelID, _ := model.ParseModelID(modelIDsWithAttributes) + modelID, attributes := model.ParseModelID(modelIDsWithAttributes) + modelIDWithProvider := providerID + provider.ProviderModelSeparator + modelID p, ok := providers[providerID] if !ok { @@ -460,18 +461,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. } // TODO If a model has not been pulled before, it was not available for at least the "Ollama" provider. Make this cleaner, we should not rebuild every time. - if _, ok := models[modelIDsWithProviderAndAttributes]; !ok { + if _, ok := models[modelIDWithProvider]; !ok { ms, err := p.Models() if err != nil { command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err) } for _, m := range ms { - if _, ok := models[m.ID()]; ok { + if _, ok := models[m.ModelID()]; ok { continue } - models[m.ID()] = m - evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ID()) + models[m.ModelID()] = m + evaluationConfiguration.Models.Available = append(evaluationConfiguration.Models.Available, m.ModelID()) } modelIDs = maps.Keys(models) sort.Strings(modelIDs) @@ -489,10 +490,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. pc.AddModel(m) } else { var ok bool - m, ok = models[modelIDsWithProviderAndAttributes] + m, ok = models[modelIDWithProvider] if !ok { command.logger.Panicf("ERROR: model %q does not exist for provider %q. Valid models are: %s", modelIDsWithProviderAndAttributes, providerID, strings.Join(modelIDs, ", ")) } + + // If a model with attributes is requested, we add the base model plus attributes as new model to our list. + if len(attributes) > 0 { + modelWithAttributes := m.Clone() + modelWithAttributes.SetAttributes(attributes) + models[modelWithAttributes.ID()] = modelWithAttributes + m = modelWithAttributes + } } evaluationContext.Models = append(evaluationContext.Models, m) evaluationContext.ProviderForModel[m] = p diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index f08424a4..874340b1 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -1478,6 +1478,44 @@ func TestEvaluateInitialize(t *testing.T) { }, config.Repositories.Selected) }, }) + validate(t, &testCase{ + Name: "Model with attributes", + + Command: makeValidCommand(func(command *Evaluate) { + command.ModelIDsWithProviderAndAttributes = []string{ + "openrouter/openai/o3-mini@reasoning_effort=low", + "openrouter/openai/o3-mini@reasoning_effort=high", + } + command.ProviderTokens = map[string]string{ + "openrouter": "fake-token", + } + }), + + ValidateContext: func(t *testing.T, context *evaluate.Context) { + assert.Len(t, context.Models, 2) + + assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=high", context.Models[0].ID()) + assert.Equal(t, "openrouter/openai/o3-mini", context.Models[0].ModelID()) + expectedAttributes := map[string]string{ + "reasoning_effort": "high", + } + assert.Equal(t, expectedAttributes, context.Models[0].Attributes()) + + assert.Equal(t, "openrouter/openai/o3-mini@reasoning_effort=low", context.Models[1].ID()) + assert.Equal(t, "openrouter/openai/o3-mini", context.Models[1].ModelID()) + expectedAttributes = map[string]string{ + "reasoning_effort": "low", + } + assert.Equal(t, expectedAttributes, context.Models[1].Attributes()) + }, + ValidateConfiguration: func(t *testing.T, config *EvaluationConfiguration) { + expectedSelected := []string{ + "openrouter/openai/o3-mini@reasoning_effort=high", + "openrouter/openai/o3-mini@reasoning_effort=low", + } + assert.Equal(t, expectedSelected, config.Models.Selected) + }, + }) validate(t, &testCase{ Name: "Local runtime does not allow parallel parameter", diff --git a/model/llm/llm.go b/model/llm/llm.go index 1f99d2cf..25f18895 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -70,7 +70,12 @@ var _ model.Model = (*Model)(nil) // ID returns full identifier, including the provider and attributes. func (m *Model) ID() (id string) { - return m.id + attributeString := "" + for key, value := range m.attributes { + attributeString += "@" + key + "=" + value + } + + return m.id + attributeString } // ModelID returns the unique identifier of this model with its provider. @@ -93,11 +98,23 @@ func (m *Model) Attributes() (attributes map[string]string) { return m.attributes } +// SetAttributes sets the given attributes. +func (m *Model) SetAttributes(attributes map[string]string) { + m.attributes = attributes +} + // MetaInformation returns the meta information of a model. func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return m.metaInformation } +// Clone returns a copy of the model. +func (m *Model) Clone() (clone model.Model) { + model := *m + + return &model +} + // llmSourceFilePromptContext is the base template context for an LLM generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. diff --git a/model/model.go b/model/model.go index 15d0c2d8..34984306 100644 --- a/model/model.go +++ b/model/model.go @@ -18,9 +18,14 @@ type Model interface { // Attributes returns query attributes. Attributes() (attributes map[string]string) + // SetAttributes sets the given attributes. + SetAttributes(attributes map[string]string) // MetaInformation returns the meta information of a model. MetaInformation() *MetaInformation + + // Clone returns a copy of the model. + Clone() (clone Model) } // ParseModelID takes a packaged model ID with optional attributes and converts it into its model ID and optional attributes. diff --git a/model/symflower/symflower.go b/model/symflower/symflower.go index 5093b91f..e7b2443b 100644 --- a/model/symflower/symflower.go +++ b/model/symflower/symflower.go @@ -82,11 +82,22 @@ func (m *Model) Attributes() (attributes map[string]string) { return nil } +// SetAttributes sets the given attributes. +func (m *Model) SetAttributes(attributes map[string]string) { +} + // MetaInformation returns the meta information of a model. func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return nil } +// Clone returns a copy of the model. +func (m *Model) Clone() (clone model.Model) { + model := *m + + return &model +} + var _ model.CapabilityWriteTests = (*Model)(nil) // WriteTests generates test files for the given implementation file in a repository. diff --git a/model/testing/Model_mock_gen.go b/model/testing/Model_mock_gen.go index 7d07fbe3..bc9a297a 100644 --- a/model/testing/Model_mock_gen.go +++ b/model/testing/Model_mock_gen.go @@ -32,6 +32,26 @@ func (_m *MockModel) Attributes() map[string]string { return r0 } +// Clone provides a mock function with given fields: +func (_m *MockModel) Clone() model.Model { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Clone") + } + + var r0 model.Model + if rf, ok := ret.Get(0).(func() model.Model); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(model.Model) + } + } + + return r0 +} + // ID provides a mock function with given fields: func (_m *MockModel) ID() string { ret := _m.Called() @@ -106,6 +126,11 @@ func (_m *MockModel) ModelIDWithoutProvider() string { return r0 } +// SetAttributes provides a mock function with given fields: attributes +func (_m *MockModel) SetAttributes(attributes map[string]string) { + _m.Called(attributes) +} + // NewMockModel creates a new instance of MockModel. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModel(t interface {