Skip to content

Commit

Permalink
Assessment for arbitrary response properties of test generation
Browse files Browse the repository at this point in the history
  • Loading branch information
bauersimon committed Apr 16, 2024
1 parent a648286 commit 9e4f513
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 30 deletions.
40 changes: 40 additions & 0 deletions evaluate/metrics/assessment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package metrics

// AssessmentKey is a description for a numerical assessment value.
type AssessmentKey string

// allAssessmentKeys holds all assessment keys.
var allAssessmentKeys []AssessmentKey

func registerAssessmentKey(key string) AssessmentKey {
assessment := AssessmentKey(key)
allAssessmentKeys = append(allAssessmentKeys, assessment)

return assessment
}

var (
// AssessmentKeyNoExplanations means that a model did not produce additional explanations.
AssessmentKeyNoExplanations = registerAssessmentKey("no-explanations")
)

// Assessments holds numerical assessment metrics.
type Assessments map[AssessmentKey]uint

// Add sums two assessment objects.
func (a Assessments) Add(o Assessments) Assessments {
if a == nil {
a = Assessments{}
}
if o == nil {
o = Assessments{}
}

assessments := map[AssessmentKey]uint{}

for _, k := range allAssessmentKeys {
assessments[k] = a[k] + o[k]
}

return Assessments(assessments)
}
13 changes: 12 additions & 1 deletion evaluate/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ type Metrics struct {

// Coverage holds the coverage of the benchmarking candidates.
Coverage []float64

// Assessments holds numerical assessments of a generation.
Assessments Assessments
}

// Add sums two metrics objects.
Expand All @@ -33,6 +36,8 @@ func (m Metrics) Add(o Metrics) Metrics {
Total: m.Total + o.Total,

Coverage: append(m.Coverage, o.Coverage...),

Assessments: m.Assessments.Add(o.Assessments),
}
}

Expand Down Expand Up @@ -70,11 +75,17 @@ func (m Metrics) String() string {

// StringCSV returns a CSV row string representation of the metrics.
func (m Metrics) StringCSV() []string {
assessment := m.Assessments
if assessment == nil {
assessment = Assessments{}
}

return []string{
fmt.Sprintf("%d", m.Total),
fmt.Sprintf("%d", m.Executed),
fmt.Sprintf("%d", m.Problems),
fmt.Sprintf("%.0f", m.AverageCoverage()),
fmt.Sprintf("%d", assessment[AssessmentKeyNoExplanations]),
}
}

Expand All @@ -83,7 +94,7 @@ func FormatStringCSV(metricsPerModel map[string]Metrics) (string, error) {
var out strings.Builder
csv := csv.NewWriter(&out)

if err := csv.Write([]string{"model", "files-total", "files-executed", "files-problems", "coverage-statement"}); err != nil {
if err := csv.Write([]string{"model", "files-total", "files-executed", "files-problems", "coverage-statement", "no-explanation"}); err != nil {
return "", err
}
categories := maps.Keys(metricsPerModel)
Expand Down
16 changes: 11 additions & 5 deletions evaluate/metrics/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func TestFormatStringCSV(t *testing.T) {
},

ExpectedString: `
model,files-total,files-executed,files-problems,coverage-statement
Model,0,0,0,0
model,files-total,files-executed,files-problems,coverage-statement,no-explanation
Model,0,0,0,0,0
`,
})
validate(t, &testCase{
Expand All @@ -46,19 +46,25 @@ func TestFormatStringCSV(t *testing.T) {
Executed: 3,
Problems: 2,
Coverage: []float64{100.0},
Assessments: Assessments{
AssessmentKeyNoExplanations: 3,
},
},
"ModelB": Metrics{
Total: 4,
Executed: 2,
Problems: 2,
Coverage: []float64{70.0},
Assessments: Assessments{
AssessmentKeyNoExplanations: 2,
},
},
},

ExpectedString: `
model,files-total,files-executed,files-problems,coverage-statement
ModelA,5,3,2,100
ModelB,4,2,2,70
model,files-total,files-executed,files-problems,coverage-statement,no-explanation
ModelA,5,3,2,100,3
ModelB,4,2,2,70,2
`,
})
}
10 changes: 10 additions & 0 deletions free.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

make install && eval-dev-quality evaluate \
--model symflower/symbolic-execution \
--model openrouter/openrouter/cinematika-7b:free \
--model openrouter/google/gemma-7b-it:free \
--model openrouter/gryphe/mythomist-7b:free \
--model openrouter/mistralai/mistral-7b-instruct:free \
--model openrouter/nousresearch/nous-capybara-7b:free \
--model openrouter/undi95/toppy-m-7b:free
17 changes: 11 additions & 6 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
pkgerrors "github.com/pkg/errors"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/model/llm/prompt"
Expand Down Expand Up @@ -77,10 +78,10 @@ func (m *llm) ID() (id string) {
}

// GenerateTestsForFile generates test files for the given implementation file in a repository.
func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error) {
func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (assessment metrics.Assessments, err error) {
data, err := os.ReadFile(filepath.Join(repositoryPath, filePath))
if err != nil {
return err
return nil, err
}
fileContent := strings.TrimSpace(string(data))

Expand All @@ -94,20 +95,24 @@ func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath st
ImportPath: importPath,
})
if err != nil {
return err
return nil, err
}

response, err := m.provider.Query(context.Background(), m.model, request)
if err != nil {
return err
return nil, err
}
log.Printf("Model %q responded to query %s with: %s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t"))), string(bytesutil.PrefixLines([]byte(response), []byte("\t"))))

testContent := prompt.ParseResponse(response)
assessment, testContent := prompt.ParseResponse(response)

// TODO Ask the model for the test file name or compute it in a more sophisticated manner.
fileExtension := filepath.Ext(filePath)
testFilePath := filepath.Join(repositoryPath, strings.TrimSuffix(filePath, fileExtension)+"_test"+fileExtension)

return os.WriteFile(testFilePath, []byte(testContent), 0644)
if err := os.WriteFile(testFilePath, []byte(testContent), 0644); err != nil {
return nil, pkgerrors.WithStack(err)
}

return assessment, nil
}
3 changes: 2 additions & 1 deletion model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
tc.SetupMock(mock)
llm := NewLLMModel(mock, tc.ModelID)

assert.NoError(t, llm.GenerateTestsForFile(tc.Language, temporaryPath, tc.SourceFilePath))
_, actualErr := llm.GenerateTestsForFile(tc.Language, temporaryPath, tc.SourceFilePath)
assert.NoError(t, actualErr)

actualTestFileContent, err := os.ReadFile(filepath.Join(temporaryPath, tc.ExpectedTestFilePath))
assert.NoError(t, err)
Expand Down
18 changes: 15 additions & 3 deletions model/llm/prompt/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"regexp"
"strings"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/zimmski/osutil/bytesutil"
)

Expand All @@ -13,16 +14,27 @@ var (
)

// ParseResponse parses code from a model's response.
func ParseResponse(response string) (code string) {
func ParseResponse(response string) (assessment metrics.Assessments, code string) {
assessment = metrics.Assessments{}

// Some models produce duplicated code tags, so unify them if needed.
response = codeTagDuplicatedRe.ReplaceAllString(response, "```")

blocks := bytesutil.GuardedBlocks(response, codeTagRe, codeTagRe)
if len(blocks) == 0 { // When no code blocks are found, assume that just the code is returned.
return strings.TrimSpace(response)
assessment[metrics.AssessmentKeyNoExplanations] = 1

return assessment, strings.TrimSpace(response)
}
// Assume the first code block contains the response code fragment.
block := blocks[0]

return strings.TrimSpace(codeTagRe.ReplaceAllString(block, ""))
responseWithoutBlock := strings.Replace(response, block, "", 1)
if len(strings.TrimSpace(responseWithoutBlock)) == 0 {
assessment[metrics.AssessmentKeyNoExplanations] = 1
} else {
assessment[metrics.AssessmentKeyNoExplanations] = 0
}

return assessment, strings.TrimSpace(codeTagRe.ReplaceAllString(block, ""))
}
41 changes: 33 additions & 8 deletions model/llm/prompt/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/zimmski/osutil/bytesutil"
)

Expand All @@ -14,13 +15,15 @@ func TestParseResponse(t *testing.T) {

Response string

ExpectedCode string
ExpectedAssessment metrics.Assessments
ExpectedCode string
}

validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
actualCode := ParseResponse(tc.Response)
actualAssessment, actualCode := ParseResponse(tc.Response)

assert.Equal(t, tc.ExpectedAssessment, actualAssessment)
assert.Equal(t, strings.TrimSpace(tc.ExpectedCode), actualCode)
})
}
Expand All @@ -38,44 +41,66 @@ func TestParseResponse(t *testing.T) {
validate(t, &testCase{
Name: "Only Code",

Response: code,
Response: code,

ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
},
ExpectedCode: code,
})

t.Run("Formatted Code", func(t *testing.T) {
validate(t, &testCase{
Name: "No Prose",

Response: "```\n" + code + "\n```\n",
Response: "```\n" + code + "\n```\n",

ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
},
ExpectedCode: code,
})

validate(t, &testCase{
Name: "With Prose",

Response: "Some text...\n\n```\n" + code + "\n```\n\nSome more text...",
Response: "Some text...\n\n```\n" + code + "\n```\n\nSome more text...",

ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 0,
},
ExpectedCode: code,
})
})

validate(t, &testCase{
Name: "Language Specified",

Response: "```go\n" + code + "\n```\n",
Response: "```go\n" + code + "\n```\n",

ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
},
ExpectedCode: code,
})

validate(t, &testCase{
Name: "Whitespace before Code Block Guards",

Response: " ```\n" + code + "\n\t```\n",
Response: " ```\n" + code + "\n\t```\n",
ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
},
ExpectedCode: code,
})

validate(t, &testCase{
Name: "Duplicated Code Block Guards",

Response: "```\n```\n" + code + "\n```\n```\n",
Response: "```\n```\n" + code + "\n```\n```\n",
ExpectedAssessment: metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
},
ExpectedCode: code,
})
}
7 changes: 5 additions & 2 deletions model/model.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package model

import "github.com/symflower/eval-dev-quality/language"
import (
"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
)

// Model defines a model that can be queried for generations.
type Model interface {
// ID returns the unique ID of this model.
ID() (id string)

// GenerateTestsForFile generates test files for the given implementation file in a repository.
GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error)
GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (assessments metrics.Assessments, err error)
}
9 changes: 6 additions & 3 deletions model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package symflower
import (
pkgerrors "github.com/pkg/errors"

"github.com/symflower/eval-dev-quality/evaluate/metrics"
"github.com/symflower/eval-dev-quality/language"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider"
Expand All @@ -20,7 +21,7 @@ func (m *ModelSymflower) ID() (id string) {
}

// GenerateTestsForFile generates test files for the given implementation file in a repository.
func (m *ModelSymflower) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error) {
func (m *ModelSymflower) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (assessment metrics.Assessments, err error) {
_, _, err = util.CommandWithResult(&util.Command{
Command: []string{
"symflower", "unit-tests",
Expand All @@ -31,8 +32,10 @@ func (m *ModelSymflower) GenerateTestsForFile(language language.Language, reposi
Directory: repositoryPath,
})
if err != nil {
return pkgerrors.WithStack(err)
return nil, pkgerrors.WithStack(err)
}

return nil
return metrics.Assessments{
metrics.AssessmentKeyNoExplanations: 1,
}, nil
}
2 changes: 1 addition & 1 deletion model/symflower/symflower_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestModelSymflowerGenerateTestsForFile(t *testing.T) {
if tc.ModelSymflower == nil {
tc.ModelSymflower = &ModelSymflower{}
}
actualErr := tc.ModelSymflower.GenerateTestsForFile(tc.Language, repositoryPath, tc.FilePath)
_, actualErr := tc.ModelSymflower.GenerateTestsForFile(tc.Language, repositoryPath, tc.FilePath)

if tc.ExpectedError != nil {
assert.ErrorIs(t, tc.ExpectedError, actualErr)
Expand Down

0 comments on commit 9e4f513

Please sign in to comment.