Skip to content

Commit

Permalink
Merge pull request #13 from symflower/dynamic-langugage-name-llm-prompt
Browse files Browse the repository at this point in the history
Dynamic language name in LLM prompt
  • Loading branch information
zimmski authored Apr 3, 2024
2 parents 80b33ac + 534fc0b commit 25d6b53
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 10 deletions.
2 changes: 1 addition & 1 deletion evaluate/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func EvaluateRepository(model model.Model, language language.Language, repositor

for _, filePath := range filePaths {
metrics.Total++
if err := model.GenerateTestsForFile(temporaryRepositoryPath, filePath); err != nil {
if err := model.GenerateTestsForFile(language, temporaryRepositoryPath, filePath); err != nil {
problems = append(problems, pkgerrors.WithMessage(err, filePath))
metrics.Problems++

Expand Down
5 changes: 5 additions & 0 deletions language/golang.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ func (language *LanguageGolang) ID() (id string) {
return "golang"
}

// Name is the prose name of this language.
func (language *LanguageGolang) Name() (id string) {
return "Go"
}

// Files returns a list of relative file paths of the repository that should be evaluated.
func (language *LanguageGolang) Files(repositoryPath string) (filePaths []string, err error) {
repositoryPath, err = filepath.Abs(repositoryPath)
Expand Down
2 changes: 2 additions & 0 deletions language/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
type Language interface {
// ID returns the unique ID of this language.
ID() (id string)
// Name is the prose name of this language.
Name() (id string)

// Files returns a list of relative file paths of the repository that should be evaluated.
Files(repositoryPath string) (filePaths []string, err error)
Expand Down
12 changes: 9 additions & 3 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
pkgerrors "github.com/pkg/errors"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-symflower-codegen-testing/language"
"github.com/symflower/eval-symflower-codegen-testing/model"
"github.com/symflower/eval-symflower-codegen-testing/model/llm/prompt"
"github.com/symflower/eval-symflower-codegen-testing/provider"
Expand All @@ -34,6 +35,9 @@ func NewLLMModel(provider provider.QueryProvider, modelIdentifier string) model.

// llmGenerateTestForFilePromptContext is the context for template for generating an LLM test generation prompt.
type llmGenerateTestForFilePromptContext struct {
// Language holds the programming language name.
Language language.Language

// Code holds the source code of the file.
Code string
// FilePath holds the file path of the file.
Expand All @@ -44,11 +48,11 @@ type llmGenerateTestForFilePromptContext struct {

// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(`
Given the following Go code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code.
Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code.
The tests should produce 100 percent code coverage and must compile.
The response must contain only the test code and nothing else.
` + "```" + `
` + "```" + `{{ .Language.ID }}
{{ .Code }}
` + "```" + `
`)))
Expand All @@ -73,7 +77,7 @@ func (m *llm) ID() (id string) {
}

// GenerateTestsForFile generates test files for the given implementation file in a repository.
func (m *llm) GenerateTestsForFile(repositoryPath string, filePath string) (err error) {
func (m *llm) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error) {
data, err := os.ReadFile(filepath.Join(repositoryPath, filePath))
if err != nil {
return err
Expand All @@ -83,6 +87,8 @@ func (m *llm) GenerateTestsForFile(repositoryPath string, filePath string) (err
importPath := filepath.Join(filepath.Base(repositoryPath), filepath.Dir(filePath))

message, err := llmGenerateTestForFilePrompt(&llmGenerateTestForFilePromptContext{
Language: language,

Code: fileContent,
FilePath: filePath,
ImportPath: importPath,
Expand Down
11 changes: 8 additions & 3 deletions model/llm/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/zimmski/osutil/bytesutil"

"github.com/symflower/eval-symflower-codegen-testing/language"
providertesting "github.com/symflower/eval-symflower-codegen-testing/provider/testing"
)

Expand All @@ -20,9 +21,10 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {

SetupMock func(mockedProvider *providertesting.MockQueryProvider)

Language language.Language
ModelID string
SourceFileContent string
SourceFilePath string
ModelID string

ExpectedTestFileContent string
ExpectedTestFilePath string
Expand All @@ -40,7 +42,7 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
tc.SetupMock(mock)
llm := NewLLMModel(mock, tc.ModelID)

assert.NoError(t, llm.GenerateTestsForFile(temporaryPath, tc.SourceFilePath))
assert.NoError(t, llm.GenerateTestsForFile(tc.Language, temporaryPath, tc.SourceFilePath))

actualTestFileContent, err := os.ReadFile(filepath.Join(temporaryPath, tc.ExpectedTestFilePath))
assert.NoError(t, err)
Expand All @@ -56,6 +58,8 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
`
sourceFilePath := "simple.go"
promptMessage, err := llmGenerateTestForFilePrompt(&llmGenerateTestForFilePromptContext{
Language: &language.LanguageGolang{},

Code: bytesutil.StringTrimIndentations(sourceFileContent),
FilePath: sourceFilePath,
ImportPath: "native",
Expand All @@ -74,9 +78,10 @@ func TestModelLLMGenerateTestsForFile(t *testing.T) {
`), nil)
},

Language: &language.LanguageGolang{},
ModelID: "model-id",
SourceFileContent: sourceFileContent,
SourceFilePath: sourceFilePath,
ModelID: "model-id",

ExpectedTestFileContent: `
package native
Expand Down
4 changes: 3 additions & 1 deletion model/model.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package model

import "github.com/symflower/eval-symflower-codegen-testing/language"

// Model defines a model that can be queried for generations.
type Model interface {
// ID returns the unique ID of this model.
ID() (id string)

// GenerateTestsForFile generates test files for the given implementation file in a repository.
GenerateTestsForFile(repositoryPath string, filePath string) (err error)
GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error)
}
3 changes: 2 additions & 1 deletion model/symflower/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package symflower
import (
pkgerrors "github.com/pkg/errors"

"github.com/symflower/eval-symflower-codegen-testing/language"
"github.com/symflower/eval-symflower-codegen-testing/model"
"github.com/symflower/eval-symflower-codegen-testing/provider"
"github.com/symflower/eval-symflower-codegen-testing/util"
Expand All @@ -19,7 +20,7 @@ func (m *ModelSymflower) ID() (id string) {
}

// GenerateTestsForFile generates test files for the given implementation file in a repository.
func (m *ModelSymflower) GenerateTestsForFile(repositoryPath string, filePath string) (err error) {
func (m *ModelSymflower) GenerateTestsForFile(language language.Language, repositoryPath string, filePath string) (err error) {
_, _, err = util.CommandWithResult(&util.Command{
Command: []string{
"symflower", "unit-tests",
Expand Down
2 changes: 1 addition & 1 deletion model/symflower/symflower_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestModelSymflowerGenerateTestsForFile(t *testing.T) {
if tc.ModelSymflower == nil {
tc.ModelSymflower = &ModelSymflower{}
}
actualErr := tc.ModelSymflower.GenerateTestsForFile(repositoryPath, tc.FilePath)
actualErr := tc.ModelSymflower.GenerateTestsForFile(tc.Language, repositoryPath, tc.FilePath)

if tc.ExpectedError != nil {
assert.ErrorIs(t, tc.ExpectedError, actualErr)
Expand Down

0 comments on commit 25d6b53

Please sign in to comment.