Skip to content

Commit

Permalink
prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Gekko0114 committed Jul 8, 2024
1 parent 2ddac23 commit 3fa0be4
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pkg/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func main() {
{Role: "user", Content: "こんにちは"},
}

response, err := llm.SendMessage(msg)
response, err := llm.Invoke(msg)
if err != nil {
fmt.Println(err)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewChatOpenAI(model string) (*ChatOpenAI, error) {
}, nil
}

func (c *ChatOpenAI) SendMessage(message []Message) (*Response, error) {
func (c *ChatOpenAI) Invoke(message []Message) (*Response, error) {

requestBody := Request{
Model: c.model,
Expand Down
4 changes: 2 additions & 2 deletions pkg/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/jarcoal/httpmock"
)

func TestSendMessage(t *testing.T) {
func TestInvoke(t *testing.T) {
tests := []struct {
name string
message []Message
Expand Down Expand Up @@ -50,7 +50,7 @@ func TestSendMessage(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
have, err := client.SendMessage(tc.message)
have, err := client.Invoke(tc.message)
if err != nil {
t.Fatalf("Error happens: %v", err)
}
Expand Down
30 changes: 30 additions & 0 deletions pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package prompt

import (
"bytes"
"fmt"
"text/template"
)

type PromptTemplate struct {
template *template.Template
}

func NewPromptTemplate(input string) (*PromptTemplate, error) {
tmpl, err := template.New("tmpl").Parse(input)
if err != nil {
return nil, fmt.Errorf("failed to parse template: %w", err)
}

return &PromptTemplate{
template: tmpl,
}, nil
}

func (t *PromptTemplate) Invoke(input any) (string, error) {
var buf bytes.Buffer
if err := t.template.Execute(&buf, input); err != nil {
return "", err
}
return buf.String(), nil
}
43 changes: 43 additions & 0 deletions pkg/prompt/prompt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package prompt

import (
"testing"
)

func TestInvoke(t *testing.T) {
tests := []struct {
name string
template string
input any
expected string
}{
{
name: "template is string",
template: "template",
expected: "template",
},
{
name: "template includes input",
template: "the meaning of {{.Word}}?",
input: map[string]string{
"Word": "satisfaction",
},
expected: "the meaning of satisfaction?",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
prompt, err := NewPromptTemplate(tc.template)
if err != nil {
t.Fatalf("Error happens: %v", err)
}
have, err := prompt.Invoke(tc.input)
if err != nil {
t.Fatalf("Error happens: %v", err)
}
if have != tc.expected {
t.Fatalf("unexpected string: %v != %v", have, tc.expected)
}
})
}
}

0 comments on commit 3fa0be4

Please sign in to comment.