diff --git a/adapter/adapter.go b/adapter/adapter.go index 5d6d7d10..0f75bf5c 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -16,6 +16,7 @@ type ChatProps struct { Reversible bool Infinity bool Message []globals.Message + Token int } func NewChatRequest(props *ChatProps, hook globals.Hook) error { @@ -31,20 +32,24 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error { props.Model, ), Message: props.Message, - Token: utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000), + Token: utils.Multi( + props.Token == 0, + utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000), + props.Token, + ), }, hook) } else if globals.IsClaudeModel(props.Model) { return claude.NewChatInstanceFromConfig().CreateStreamChatRequest(&claude.ChatProps{ Model: props.Model, Message: props.Message, - Token: 50000, + Token: utils.Multi(props.Token == 0, 50000, props.Token), }, hook) } else if globals.IsSparkDeskModel(props.Model) { return sparkdesk.NewChatInstance().CreateStreamChatRequest(&sparkdesk.ChatProps{ Message: props.Message, - Token: 2048, + Token: utils.Multi(props.Token == 0, 2500, props.Token), }, hook) } else if globals.IsPalm2Model(props.Model) { diff --git a/manager/router.go b/manager/router.go index ff615a60..f6be4967 100644 --- a/manager/router.go +++ b/manager/router.go @@ -4,4 +4,5 @@ import "github.com/gin-gonic/gin" func Register(app *gin.Engine) { app.GET("/chat", ChatAPI) + app.POST("/v1/completions", TranshipmentAPI) } diff --git a/manager/transhipment.go b/manager/transhipment.go new file mode 100644 index 00000000..f4e21da3 --- /dev/null +++ b/manager/transhipment.go @@ -0,0 +1,206 @@ +package manager + +import ( + "chat/adapter" + "chat/auth" + "chat/globals" + "chat/utils" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "time" +) + +type TranshipmentForm struct { + Model string `json:"model" binding:"required"` + Messages []globals.Message `json:"messages" binding:"required"` + Stream bool `json:"stream"` + MaxTokens int `json:"max_tokens"` +} + +type Choice struct { + Index int `json:"index"` + Message globals.Message `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type TranshipmentResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` + Quota float32 `json:"quota"` +} + +type ChoiceDelta struct { + Index int `json:"index"` + Delta globals.Message `json:"delta"` + FinishReason interface{} `json:"finish_reason"` +} + +type TranshipmentStreamResponse struct { + Id string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChoiceDelta `json:"choices"` + Usage Usage `json:"usage"` + Quota float32 `json:"quota"` +} + +func TranshipmentAPI(c *gin.Context) { + username := utils.GetUserFromContext(c) + if username == "" { + c.AbortWithStatusJSON(403, gin.H{ + "code": 403, + "message": "Access denied. Please provide correct api key.", + }) + return + } + + if utils.GetAgentFromContext(c) != "api" { + c.AbortWithStatusJSON(403, gin.H{ + "code": 403, + "message": "Access denied. Please provide correct api key.", + }) + return + } + + var form TranshipmentForm + if err := c.ShouldBindJSON(&form); err != nil { + c.JSON(400, gin.H{ + "status": false, + "error": "invalid request body", + "reason": err.Error(), + }) + return + } + + db := utils.GetDBFromContext(c) + cache := utils.GetCacheFromContext(c) + user := &auth.User{ + Username: username, + } + id := utils.Md5Encrypt(username + form.Model + time.Now().String()) + created := time.Now().Unix() + + reversible := globals.IsGPT4NativeModel(form.Model) && auth.CanEnableSubscription(db, cache, user) + + if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) { + c.JSON(http.StatusForbidden, gin.H{ + "status": false, + "error": "quota exceeded", + "reason": "not enough quota to use this model", + }) + return + } + + if form.Stream { + sendStreamTranshipmentResponse(c, form, id, created, user, reversible) + } else { + sendTranshipmentResponse(c, form, id, created, user, reversible) + } +} + +func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) { + buffer := utils.NewBuffer(form.Model, form.Messages) + err := adapter.NewChatRequest(&adapter.ChatProps{ + Model: form.Model, + Message: form.Messages, + Reversible: reversible && globals.IsGPT4Model(form.Model), + Token: form.MaxTokens, + }, func(data string) error { + buffer.Write(data) + return nil + }) + if err != nil { + fmt.Println(fmt.Sprintf("error from chat request api: %s", err.Error())) + } + + CollectQuota(c, user, buffer.GetQuota(), reversible) + c.JSON(http.StatusOK, TranshipmentResponse{ + Id: id, + Object: "chat.completion", + Created: created, + Model: form.Model, + Choices: []Choice{ + { + Index: 0, + Message: globals.Message{Role: "assistant", Content: buffer.ReadWithDefault(defaultMessage)}, + }, + }, + Usage: Usage{ + PromptTokens: int(buffer.CountInputToken()), + CompletionTokens: int(buffer.CountOutputToken()), + TotalTokens: int(buffer.CountToken()), + }, + Quota: buffer.GetQuota(), + }) +} + +func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool) TranshipmentStreamResponse { + return TranshipmentStreamResponse{ + Id: id, + Object: "chat.completion", + Created: created, + Model: form.Model, + Choices: []ChoiceDelta{ + { + Index: 0, + Delta: globals.Message{ + Role: "assistant", + Content: data, + }, + FinishReason: utils.Multi[interface{}](end, "stop", nil), + }, + }, + Usage: Usage{ + PromptTokens: int(buffer.CountInputToken()), + CompletionTokens: int(buffer.CountOutputToken()), + TotalTokens: int(buffer.CountToken()), + }, + Quota: buffer.GetQuota(), + } +} + +func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) { + channel := make(chan TranshipmentStreamResponse) + + go func() { + buffer := utils.NewBuffer(form.Model, form.Messages) + if err := adapter.NewChatRequest(&adapter.ChatProps{ + Model: form.Model, + Message: form.Messages, + Reversible: reversible && globals.IsGPT4Model(form.Model), + Token: form.MaxTokens, + }, func(data string) error { + channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false) + return nil + }); err != nil { + channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) + CollectQuota(c, user, buffer.GetQuota(), reversible) + return + } + + channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true) + CollectQuota(c, user, buffer.GetQuota(), reversible) + return + }() + + c.Stream(func(w io.Writer) bool { + if resp, ok := <-channel; ok { + c.SSEvent("message", resp) + return true + } + return false + }) +} diff --git a/middleware/auth.go b/middleware/auth.go index ef3fbb67..3cf31a9c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,6 +4,7 @@ import ( "chat/auth" "chat/utils" "github.com/gin-gonic/gin" + "net/http" "strings" ) @@ -23,13 +24,13 @@ func ProcessKey(c *gin.Context, key string) bool { cache := utils.GetCacheFromContext(c) if utils.IsInBlackList(cache, addr) { - c.AbortWithStatusJSON(200, gin.H{ + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ "code": 403, "message": "ip in black list", }) return false } - + if user := auth.ParseApiKey(c, key); user != nil { c.Set("auth", true) c.Set("user", user.Username) @@ -39,6 +40,10 @@ func ProcessKey(c *gin.Context, key string) bool { } utils.IncrIP(cache, addr) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Access denied. Please provide correct api key.", + }) return false } diff --git a/utils/buffer.go b/utils/buffer.go index c3af1b0f..89908b25 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -6,20 +6,22 @@ import ( ) type Buffer struct { - Model string `json:"model"` - Quota float32 `json:"quota"` - Data string `json:"data"` - Cursor int `json:"cursor"` - Times int `json:"times"` + Model string `json:"model"` + Quota float32 `json:"quota"` + Data string `json:"data"` + Cursor int `json:"cursor"` + Times int `json:"times"` + History []globals.Message `json:"history"` } func NewBuffer(model string, history []globals.Message) *Buffer { return &Buffer{ - Data: "", - Cursor: 0, - Times: 0, - Model: model, - Quota: CountInputToken(model, history), + Data: "", + Cursor: 0, + Times: 0, + Model: model, + Quota: CountInputToken(model, history), + History: history, } } @@ -73,3 +75,19 @@ func (b *Buffer) ReadWithDefault(_default string) string { func (b *Buffer) ReadTimes() int { return b.Times } + +func (b *Buffer) ReadHistory() []globals.Message { + return b.History +} + +func (b *Buffer) CountInputToken() float32 { + return CountInputToken(b.Model, b.ReadHistory()) +} + +func (b *Buffer) CountOutputToken() float32 { + return CountOutputToken(b.Model, b.ReadTimes()) +} + +func (b *Buffer) CountToken() float32 { + return b.CountInputToken() + b.CountOutputToken() +}