-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconversation_manager.go
104 lines (86 loc) · 2.69 KB
/
conversation_manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package main
import (
"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)
// Conversation is the chat history for a single conversation together with
// the system message it was started with.
//
// Messages is the ordered list of request messages; startConversation puts the
// system message at index 0 and later user/assistant messages are appended.
// SystemMessage is the system prompt text used to seed Messages, kept so that a
// reset can start a new history with the same prompt.
type Conversation struct {
	Messages      []azopenai.ChatRequestMessageClassification
	SystemMessage string
}
// ConversationManager stores per-id conversation state: the chat history and
// the most recent image ("draw") prompt.
//
// The int64 ids are presumably chat IDs from the bot framework — confirm with
// callers. The maps are plain Go maps with no locking, so a ConversationManager
// must not be used from multiple goroutines without external synchronization.
type ConversationManager struct {
	// conversations maps an id to its current conversation.
	conversations map[int64]*Conversation
	// drawPrompts maps an id to the last prompt recorded via SetLastDrawPrompt.
	drawPrompts map[int64]string
	// pastMessagesIncluded caps the history length (system message included)
	// that AddResponse keeps before dropping the oldest user/assistant pair.
	pastMessagesIncluded int
}
// defaultSystemMessage is the system prompt used for conversations that have
// never had a custom system message set (or were cleared with ResetAll).
const defaultSystemMessage = "You are a helpful assistant."
// NewConversationManager returns an empty manager whose histories are trimmed
// by AddResponse once they exceed pastMessagesIncluded messages (the system
// message counts toward that limit).
func NewConversationManager(pastMessagesIncluded int) *ConversationManager {
	mgr := new(ConversationManager)
	// Both maps must be allocated up front: writing to a nil map panics.
	mgr.conversations = map[int64]*Conversation{}
	mgr.drawPrompts = map[int64]string{}
	mgr.pastMessagesIncluded = pastMessagesIncluded
	return mgr
}
// GetConversation returns the conversation stored for id. If there is none, it
// creates, stores and returns a new one seeded with the default system message.
func (c *ConversationManager) GetConversation(id int64) *Conversation {
	existing, found := c.conversations[id]
	if !found {
		return c.ResetAll(id)
	}
	return existing
}
// Reset discards the message history for id and stores a new conversation that
// keeps id's current system message (or the default if id has none). It
// returns the newly stored conversation.
func (c *ConversationManager) Reset(id int64) *Conversation {
	fresh := startConversation(c.getSystemMessage(id))
	stored := &fresh
	c.conversations[id] = stored
	return stored
}
// ResetAll discards both the history and any custom system message for id and
// returns a new conversation seeded with the default system message. It does
// not touch the stored draw prompt.
func (c *ConversationManager) ResetAll(id int64) *Conversation {
	// With the entry gone, Reset's lookup of the system message falls back to
	// defaultSystemMessage.
	delete(c.conversations, id)
	fresh := c.Reset(id)
	return fresh
}
// SetSystemMessage starts a new conversation for id seeded with systemMessage,
// discarding any previous history, and returns it. The new system message is
// kept across later Reset calls until changed again or cleared by ResetAll.
func (c *ConversationManager) SetSystemMessage(id int64, systemMessage string) *Conversation {
	// Build the conversation directly instead of assigning through
	// c.conversations[id]: that entry is nil when id has no conversation yet,
	// and dereferencing it would panic.
	conv := startConversation(systemMessage)
	c.conversations[id] = &conv
	return &conv
}
// getSystemMessage returns the system message of id's stored conversation, or
// defaultSystemMessage when id has no conversation.
func (c *ConversationManager) getSystemMessage(id int64) string {
	conv, ok := c.conversations[id]
	if !ok {
		return defaultSystemMessage
	}
	return conv.SystemMessage
}
// startConversation builds a new Conversation whose history contains only the
// given system message.
func startConversation(systemMessage string) Conversation {
	seed := &azopenai.ChatRequestSystemMessage{Content: to.Ptr(systemMessage)}

	var conv Conversation
	conv.Messages = []azopenai.ChatRequestMessageClassification{seed}
	conv.SystemMessage = systemMessage
	return conv
}
// AddUserMessage appends userInput as a user message to id's conversation
// (creating the conversation if needed) and returns the updated conversation.
func (c *ConversationManager) AddUserMessage(id int64, userInput string) *Conversation {
	conv := c.GetConversation(id)
	userMsg := &azopenai.ChatRequestUserMessage{
		Content: azopenai.NewChatRequestUserMessageContent(userInput),
	}
	conv.Messages = append(conv.Messages, userMsg)
	return conv
}
// AddResponse appends response as an assistant message to id's conversation
// (creating the conversation if needed) and then trims old history.
//
// Trimming keeps the system message at index 0 and repeatedly drops the two
// messages at indexes 1 and 2 (normally the oldest user message and its reply)
// while the history is longer than pastMessagesIncluded, as long as more than
// three messages remain.
func (c *ConversationManager) AddResponse(id int64, response string) {
	conv := c.GetConversation(id)
	conv.Messages = append(
		conv.Messages,
		&azopenai.ChatRequestAssistantMessage{
			Content: to.Ptr(response),
		},
	)
	// Loop rather than trimming once: if user messages were added without a
	// matching response, one trim per response would let the history grow
	// forever. When user/assistant messages alternate, this runs at most once.
	for len(conv.Messages) > c.pastMessagesIncluded && len(conv.Messages) > 3 {
		// Allocate a new slice rather than shifting in place so any earlier
		// copy of conv.Messages keeps its contents.
		conv.Messages = append([]azopenai.ChatRequestMessageClassification{conv.Messages[0]}, conv.Messages[3:]...)
	}
}
// GetLastDrawPrompt returns the prompt last stored for id with
// SetLastDrawPrompt, or "" if none was stored.
func (c *ConversationManager) GetLastDrawPrompt(id int64) string {
	prompt := c.drawPrompts[id]
	return prompt
}
// SetLastDrawPrompt records prompt as the most recent draw prompt for id,
// replacing any previous value.
func (c *ConversationManager) SetLastDrawPrompt(id int64, prompt string) {
	prompts := c.drawPrompts
	prompts[id] = prompt
}