Skip to content

Commit

Permalink
update signal and form type: ping, restart, stop, chat
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Oct 21, 2023
1 parent 7696ec3 commit f744404
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 56 deletions.
38 changes: 35 additions & 3 deletions manager/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"strings"
)

const defaultMessage = "Sorry, I don't understand. Please try again."
Expand All @@ -25,7 +26,34 @@ func CollectQuota(c *gin.Context, user *auth.User, quota float32, reversible boo
}
}

func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation.Conversation) string {
func ImageHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
if user == nil {
conn.Send(globals.ChatSegmentResponse{
Message: "You need to login to use this feature.",
End: true,
})
return "You need to login to use this feature."
}

prompt := strings.TrimSpace(strings.TrimPrefix(instance.GetLatestMessage(), "/image"))

if response, err := GenerateImage(conn.GetCtx(), user, prompt); err != nil {
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
End: true,
})
return err.Error()
} else {
conn.Send(globals.ChatSegmentResponse{
Quota: 1.,
Message: response,
End: true,
})
return response
}
}

func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
defer func() {
if err := recover(); err != nil {
fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
Expand Down Expand Up @@ -70,12 +98,16 @@ func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation.
Message: segment,
Reversible: reversible && globals.IsGPT4Model(model),
}, func(data string) error {
return conn.SendJSON(globals.ChatSegmentResponse{
if signal := conn.PeekWithType(StopType); signal != nil {
// stop signal from client
return fmt.Errorf("signal")
}
return conn.SendClient(globals.ChatSegmentResponse{
Message: buffer.Write(data),
Quota: buffer.GetQuota(),
End: false,
})
}); err != nil {
}); err != nil && err.Error() != "signal" {
CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible)
conn.Send(globals.ChatSegmentResponse{
Message: err.Error(),
Expand Down
177 changes: 177 additions & 0 deletions manager/connection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package manager

import (
"chat/globals"
"chat/manager/conversation"
"chat/utils"
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
)

const (
ChatType = "chat"
StopType = "stop"
RestartType = "restart"
ShareType = "share"
)

type Stack []*conversation.FormMessage

type Connection struct {
conn *utils.WebSocket
stack Stack
auth bool
hash string
}

func NewConnection(conn *utils.WebSocket, auth bool, hash string, bufferSize int) *Connection {
buf := &Connection{
conn: conn,
auth: auth,
hash: hash,
stack: make(Stack, bufferSize),
}
buf.ReadWorker()
return buf
}

func (c *Connection) GetConn() *utils.WebSocket {
return c.conn
}

func (c *Connection) GetCtx() *gin.Context {
return c.conn.GetCtx()
}

func (c *Connection) GetStack() Stack {
return c.stack
}

func (c *Connection) ReadWorker() {
go func() {
for {
form := utils.ReadForm[conversation.FormMessage](c.conn)
if form.Type == "" {
form.Type = ChatType
}

c.Write(form)

if form == nil {
return
}
}
}()
}

func (c *Connection) Write(data *conversation.FormMessage) {
c.stack = append(c.stack, data)
}

func (c *Connection) Read() (*conversation.FormMessage, bool) {
if len(c.stack) == 0 {
return nil, false
}
form := c.stack[0]
c.Skip()
return form, true
}

func (c *Connection) ReadWithBlock() *conversation.FormMessage {
// return: nil if connection is closed
for {
if form, ok := c.Read(); ok {
return form
}
}
}

func (c *Connection) Peek() *conversation.FormMessage {
// return nil if no message is received
if len(c.stack) == 0 {
return nil
}
return c.ReadWithBlock()
}

func (c *Connection) PeekWithType(t string) *conversation.FormMessage {
// skip if type is matched

if form := c.Peek(); form != nil {
if form.Type == t {
c.Skip()
return form
}
}

return nil
}

func (c *Connection) Skip() {
if len(c.stack) == 0 {
return
}
c.stack = c.stack[1:]
}

func (c *Connection) GetDB() *sql.DB {
return c.conn.GetDB()
}

func (c *Connection) GetCache() *redis.Client {
return c.conn.GetCache()
}

func (c *Connection) Send(message globals.ChatSegmentResponse) {
c.conn.Send(message)
}

func (c *Connection) SendClient(message globals.ChatSegmentResponse) error {
return c.conn.SendJSON(message)
}

func (c *Connection) Handle(handler func(*conversation.FormMessage) error) {
defer c.conn.DeferClose()

for {
form := c.ReadWithBlock()
if form == nil {
return
}

if !c.Lock() {
return
}

if err := handler(form); err != nil {
return
}

c.Release()
}
}

func (c *Connection) Lock() bool {
state := c.conn.IncrRateWithLimit(
c.hash,
utils.Multi[int64](c.auth, globals.ChatMaxThread, globals.AnonymousMaxThread),
60,
)

if !state {
c.conn.Send(globals.ChatSegmentResponse{
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread),
End: true,
})

return false
}

return true
}

func (c *Connection) Release() {
c.conn.DecrRate(c.hash)
}
27 changes: 25 additions & 2 deletions manager/conversation/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@ type Conversation struct {
Message []globals.Message `json:"message"`
Model string `json:"model"`
EnableWeb bool `json:"enable_web"`
Shared bool `json:"shared"`
}

type FormMessage struct {
Type string `json:"type"` // ping
Type string `json:"type"`
Message string `json:"message"`
Web bool `json:"web"`
Model string `json:"model"`
Expand Down Expand Up @@ -117,7 +118,7 @@ func (c *Conversation) GetMessageById(id int) globals.Message {
return c.Message[id]
}

func (c *Conversation) GetMessageSize() int {
func (c *Conversation) GetMessageLength() int {
return len(c.Message)
}

Expand All @@ -140,6 +141,18 @@ func (c *Conversation) AddMessage(message globals.Message) {
c.Message = append(c.Message, message)
}

func (c *Conversation) AddMessages(messages []globals.Message) {
c.Message = append(c.Message, messages...)
}

func (c *Conversation) InsertMessage(message globals.Message, index int) {
c.Message = append(c.Message[:index], append([]globals.Message{message}, c.Message[index:]...)...)
}

func (c *Conversation) InsertMessages(messages []globals.Message, index int) {
c.Message = append(c.Message[:index], append(messages, c.Message[index:]...)...)
}

func (c *Conversation) AddMessageFromUser(message string) {
c.AddMessage(globals.Message{
Role: "user",
Expand Down Expand Up @@ -231,3 +244,13 @@ func (c *Conversation) SaveResponse(db *sql.DB, message string) {
c.AddMessageFromAssistant(message)
c.SaveConversation(db)
}

func (c *Conversation) RemoveMessage(index int) globals.Message {
message := c.Message[index]
c.Message = append(c.Message[:index], c.Message[index+1:]...)
return message
}

func (c *Conversation) RemoveLatestMessage() globals.Message {
return c.RemoveMessage(len(c.Message) - 1)
}
7 changes: 4 additions & 3 deletions manager/conversation/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func UseSharedConversation(db *sql.DB, user *auth.User, hash string) *Conversati
}

func (c *Conversation) LoadSharing(db *sql.DB, hash string) {
if strings.TrimSpace(hash) == "" {
if strings.TrimSpace(hash) == "" || c.Shared == true {
return
}

Expand All @@ -139,6 +139,7 @@ func (c *Conversation) LoadSharing(db *sql.DB, hash string) {
return
}

c.Message = shared.Messages
c.Name = shared.Name
c.InsertMessages(shared.Messages, 0)
c.SetName(db, shared.Name)
c.Shared = true
}
Loading

0 comments on commit f744404

Please sign in to comment.