diff --git a/manager/chat.go b/manager/chat.go index 21f45825..48b31e40 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -9,6 +9,7 @@ import ( "chat/utils" "fmt" "github.com/gin-gonic/gin" + "strings" ) const defaultMessage = "Sorry, I don't understand. Please try again." @@ -25,7 +26,34 @@ func CollectQuota(c *gin.Context, user *auth.User, quota float32, reversible boo } } -func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation.Conversation) string { +func ImageHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { + if user == nil { + conn.Send(globals.ChatSegmentResponse{ + Message: "You need to login to use this feature.", + End: true, + }) + return "You need to login to use this feature." + } + + prompt := strings.TrimSpace(strings.TrimPrefix(instance.GetLatestMessage(), "/image")) + + if response, err := GenerateImage(conn.GetCtx(), user, prompt); err != nil { + conn.Send(globals.ChatSegmentResponse{ + Message: err.Error(), + End: true, + }) + return err.Error() + } else { + conn.Send(globals.ChatSegmentResponse{ + Quota: 1., + Message: response, + End: true, + }) + return response + } +} + +func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { defer func() { if err := recover(); err != nil { fmt.Println(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)", @@ -70,12 +98,16 @@ func ChatHandler(conn *utils.WebSocket, user *auth.User, instance *conversation. Message: segment, Reversible: reversible && globals.IsGPT4Model(model), }, func(data string) error { - return conn.SendJSON(globals.ChatSegmentResponse{ + if signal := conn.PeekWithType(StopType); signal != nil { + // stop signal from client + return fmt.Errorf("signal") + } + return conn.SendClient(globals.ChatSegmentResponse{ Message: buffer.Write(data), Quota: buffer.GetQuota(), End: false, }) - }); err != nil { + }); err != nil && err.Error() != "signal" { CollectQuota(conn.GetCtx(), user, buffer.GetQuota(), reversible) conn.Send(globals.ChatSegmentResponse{ Message: err.Error(), diff --git a/manager/connection.go b/manager/connection.go new file mode 100644 index 00000000..1cf18106 --- /dev/null +++ b/manager/connection.go @@ -0,0 +1,177 @@ +package manager + +import ( + "chat/globals" + "chat/manager/conversation" + "chat/utils" + "database/sql" + "fmt" + "github.com/gin-gonic/gin" + "github.com/go-redis/redis/v8" +) + +const ( + ChatType = "chat" + StopType = "stop" + RestartType = "restart" + ShareType = "share" +) + +type Stack []*conversation.FormMessage + +type Connection struct { + conn *utils.WebSocket + stack Stack + auth bool + hash string +} + +func NewConnection(conn *utils.WebSocket, auth bool, hash string, bufferSize int) *Connection { + buf := &Connection{ + conn: conn, + auth: auth, + hash: hash, + stack: make(Stack, bufferSize), + } + buf.ReadWorker() + return buf +} + +func (c *Connection) GetConn() *utils.WebSocket { + return c.conn +} + +func (c *Connection) GetCtx() *gin.Context { + return c.conn.GetCtx() +} + +func (c *Connection) GetStack() Stack { + return c.stack +} + +func (c *Connection) ReadWorker() { + go func() { + for { + form := utils.ReadForm[conversation.FormMessage](c.conn) + if form.Type == "" { + form.Type = ChatType + } + + c.Write(form) + + if form == nil { + return + } + } + }() +} + +func (c *Connection) Write(data *conversation.FormMessage) { + c.stack = append(c.stack, data) +} + +func (c *Connection) Read() (*conversation.FormMessage, bool) { + if len(c.stack) == 0 { + return nil, false + } + form := c.stack[0] + c.Skip() + return form, true +} + +func (c *Connection) ReadWithBlock() *conversation.FormMessage { + // return: nil if connection is closed + for { + if form, ok := c.Read(); ok { + return form + } + } +} + +func (c *Connection) Peek() *conversation.FormMessage { + // return nil if no message is received + if len(c.stack) == 0 { + return nil + } + return c.ReadWithBlock() +} + +func (c *Connection) PeekWithType(t string) *conversation.FormMessage { + // skip if type is matched + + if form := c.Peek(); form != nil { + if form.Type == t { + c.Skip() + return form + } + } + + return nil +} + +func (c *Connection) Skip() { + if len(c.stack) == 0 { + return + } + c.stack = c.stack[1:] +} + +func (c *Connection) GetDB() *sql.DB { + return c.conn.GetDB() +} + +func (c *Connection) GetCache() *redis.Client { + return c.conn.GetCache() +} + +func (c *Connection) Send(message globals.ChatSegmentResponse) { + c.conn.Send(message) +} + +func (c *Connection) SendClient(message globals.ChatSegmentResponse) error { + return c.conn.SendJSON(message) +} + +func (c *Connection) Handle(handler func(*conversation.FormMessage) error) { + defer c.conn.DeferClose() + + for { + form := c.ReadWithBlock() + if form == nil { + return + } + + if !c.Lock() { + return + } + + if err := handler(form); err != nil { + return + } + + c.Release() + } +} + +func (c *Connection) Lock() bool { + state := c.conn.IncrRateWithLimit( + c.hash, + utils.Multi[int64](c.auth, globals.ChatMaxThread, globals.AnonymousMaxThread), + 60, + ) + + if !state { + c.conn.Send(globals.ChatSegmentResponse{ + Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread), + End: true, + }) + + return false + } + + return true +} + +func (c *Connection) Release() { + c.conn.DecrRate(c.hash) +} diff --git a/manager/conversation/conversation.go b/manager/conversation/conversation.go index efe8e92a..7de6f60d 100644 --- a/manager/conversation/conversation.go +++ b/manager/conversation/conversation.go @@ -17,10 +17,11 @@ type Conversation struct { Message []globals.Message `json:"message"` Model string `json:"model"` EnableWeb bool `json:"enable_web"` + Shared bool `json:"shared"` } type FormMessage struct { - Type string `json:"type"` // ping + Type string `json:"type"` Message string `json:"message"` Web bool `json:"web"` Model string `json:"model"` @@ -117,7 +118,7 @@ func (c *Conversation) GetMessageById(id int) globals.Message { return c.Message[id] } -func (c *Conversation) GetMessageSize() int { +func (c *Conversation) GetMessageLength() int { return len(c.Message) } @@ -140,6 +141,18 @@ func (c *Conversation) AddMessage(message globals.Message) { c.Message = append(c.Message, message) } +func (c *Conversation) AddMessages(messages []globals.Message) { + c.Message = append(c.Message, messages...) +} + +func (c *Conversation) InsertMessage(message globals.Message, index int) { + c.Message = append(c.Message[:index], append([]globals.Message{message}, c.Message[index:]...)...) +} + +func (c *Conversation) InsertMessages(messages []globals.Message, index int) { + c.Message = append(c.Message[:index], append(messages, c.Message[index:]...)...) +} + func (c *Conversation) AddMessageFromUser(message string) { c.AddMessage(globals.Message{ Role: "user", @@ -231,3 +244,13 @@ func (c *Conversation) SaveResponse(db *sql.DB, message string) { c.AddMessageFromAssistant(message) c.SaveConversation(db) } + +func (c *Conversation) RemoveMessage(index int) globals.Message { + message := c.Message[index] + c.Message = append(c.Message[:index], c.Message[index+1:]...) + return message +} + +func (c *Conversation) RemoveLatestMessage() globals.Message { + return c.RemoveMessage(len(c.Message) - 1) +} diff --git a/manager/conversation/shared.go b/manager/conversation/shared.go index 5597021f..25c1a47b 100644 --- a/manager/conversation/shared.go +++ b/manager/conversation/shared.go @@ -130,7 +130,7 @@ func UseSharedConversation(db *sql.DB, user *auth.User, hash string) *Conversati } func (c *Conversation) LoadSharing(db *sql.DB, hash string) { - if strings.TrimSpace(hash) == "" { + if strings.TrimSpace(hash) == "" || c.Shared == true { return } @@ -139,6 +139,7 @@ func (c *Conversation) LoadSharing(db *sql.DB, hash string) { return } - c.Message = shared.Messages - c.Name = shared.Name + c.InsertMessages(shared.Messages, 0) + c.SetName(db, shared.Name) + c.Shared = true } diff --git a/manager/manager.go b/manager/manager.go index 15fed160..10fb2ee4 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -2,7 +2,6 @@ package manager import ( "chat/auth" - "chat/globals" "chat/manager/conversation" "chat/utils" "fmt" @@ -17,32 +16,9 @@ type WebsocketAuthForm struct { Ref string `json:"ref"` } -func EventHandler(conn *utils.WebSocket, instance *conversation.Conversation, user *auth.User) string { +func EventHandler(conn *Connection, instance *conversation.Conversation, user *auth.User) string { if strings.HasPrefix(instance.GetLatestMessage(), "/image") { - if user == nil { - conn.Send(globals.ChatSegmentResponse{ - Message: "You need to login to use this feature.", - End: true, - }) - return "You need to login to use this feature." - } - - prompt := strings.TrimSpace(strings.TrimPrefix(instance.GetLatestMessage(), "/image")) - - if response, err := GenerateImage(conn.GetCtx(), user, prompt); err != nil { - conn.Send(globals.ChatSegmentResponse{ - Message: err.Error(), - End: true, - }) - return err.Error() - } else { - conn.Send(globals.ChatSegmentResponse{ - Quota: 1., - Message: response, - End: true, - }) - return response - } + return ImageHandler(conn, user, instance) } else { return ChatHandler(conn, user, instance) } @@ -53,7 +29,6 @@ func ChatAPI(c *gin.Context) { if conn = utils.NewWebsocket(c, false); conn == nil { return } - defer conn.DeferClose() db := utils.GetDBFromContext(c) @@ -74,28 +49,26 @@ func ChatAPI(c *gin.Context) { c.ClientIP(), ))) - for { - var form *conversation.FormMessage - if form = utils.ReadForm[conversation.FormMessage](conn); form == nil { - return - } - - if instance.HandleMessage(db, form) { - if !conn.IncrRateWithLimit( - hash, - utils.Multi[int64](authenticated, globals.ChatMaxThread, globals.AnonymousMaxThread), - 60, - ) { - conn.Send(globals.ChatSegmentResponse{ - Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread), - End: true, - }) - return + buf := NewConnection(conn, authenticated, hash, 10) + buf.Handle(func(form *conversation.FormMessage) error { + switch form.Type { + case ChatType: + if instance.HandleMessage(db, form) { + response := EventHandler(buf, instance, user) + instance.SaveResponse(db, response) } - - response := EventHandler(conn, instance, user) - conn.DecrRate(hash) + case StopType: + break + case ShareType: + instance.LoadSharing(db, form.Message) + case RestartType: + if message := instance.RemoveLatestMessage(); message.Role != "user" { + return fmt.Errorf("message type error") + } + response := EventHandler(buf, instance, user) instance.SaveResponse(db, response) } - } + + return nil + }) }