Skip to content

Commit

Permalink
add key transhipment feature
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Oct 2, 2023
1 parent 98a40fb commit 1bae9b4
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 15 deletions.
11 changes: 8 additions & 3 deletions adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type ChatProps struct {
Reversible bool
Infinity bool
Message []globals.Message
Token int
}

func NewChatRequest(props *ChatProps, hook globals.Hook) error {
Expand All @@ -31,20 +32,24 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
props.Model,
),
Message: props.Message,
Token: utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000),
Token: utils.Multi(
props.Token == 0,
utils.Multi(globals.IsGPT4Model(props.Model) || props.Reversible || props.Infinity, -1, 2000),
props.Token,
),
}, hook)

} else if globals.IsClaudeModel(props.Model) {
return claude.NewChatInstanceFromConfig().CreateStreamChatRequest(&claude.ChatProps{
Model: props.Model,
Message: props.Message,
Token: 50000,
Token: utils.Multi(props.Token == 0, 50000, props.Token),
}, hook)

} else if globals.IsSparkDeskModel(props.Model) {
return sparkdesk.NewChatInstance().CreateStreamChatRequest(&sparkdesk.ChatProps{
Message: props.Message,
Token: 2048,
Token: utils.Multi(props.Token == 0, 2500, props.Token),
}, hook)

} else if globals.IsPalm2Model(props.Model) {
Expand Down
1 change: 1 addition & 0 deletions manager/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ import "github.com/gin-gonic/gin"

func Register(app *gin.Engine) {
app.GET("/chat", ChatAPI)
app.POST("/v1/completions", TranshipmentAPI)
}
206 changes: 206 additions & 0 deletions manager/transhipment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
package manager

import (
"chat/adapter"
"chat/auth"
"chat/globals"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"time"
)

type TranshipmentForm struct {
Model string `json:"model" binding:"required"`
Messages []globals.Message `json:"messages" binding:"required"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
}

type Choice struct {
Index int `json:"index"`
Message globals.Message `json:"message"`
FinishReason string `json:"finish_reason"`
}

type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}

type TranshipmentResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
Quota float32 `json:"quota"`
}

type ChoiceDelta struct {
Index int `json:"index"`
Delta globals.Message `json:"delta"`
FinishReason interface{} `json:"finish_reason"`
}

type TranshipmentStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChoiceDelta `json:"choices"`
Usage Usage `json:"usage"`
Quota float32 `json:"quota"`
}

func TranshipmentAPI(c *gin.Context) {
username := utils.GetUserFromContext(c)
if username == "" {
c.AbortWithStatusJSON(403, gin.H{
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
return
}

if utils.GetAgentFromContext(c) != "api" {
c.AbortWithStatusJSON(403, gin.H{
"code": 403,
"message": "Access denied. Please provide correct api key.",
})
return
}

var form TranshipmentForm
if err := c.ShouldBindJSON(&form); err != nil {
c.JSON(400, gin.H{
"status": false,
"error": "invalid request body",
"reason": err.Error(),
})
return
}

db := utils.GetDBFromContext(c)
cache := utils.GetCacheFromContext(c)
user := &auth.User{
Username: username,
}
id := utils.Md5Encrypt(username + form.Model + time.Now().String())
created := time.Now().Unix()

reversible := globals.IsGPT4NativeModel(form.Model) && auth.CanEnableSubscription(db, cache, user)

if !auth.CanEnableModelWithSubscription(db, user, form.Model, reversible) {
c.JSON(http.StatusForbidden, gin.H{
"status": false,
"error": "quota exceeded",
"reason": "not enough quota to use this model",
})
return
}

if form.Stream {
sendStreamTranshipmentResponse(c, form, id, created, user, reversible)
} else {
sendTranshipmentResponse(c, form, id, created, user, reversible)
}
}

func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) {
buffer := utils.NewBuffer(form.Model, form.Messages)
err := adapter.NewChatRequest(&adapter.ChatProps{
Model: form.Model,
Message: form.Messages,
Reversible: reversible && globals.IsGPT4Model(form.Model),
Token: form.MaxTokens,
}, func(data string) error {
buffer.Write(data)
return nil
})
if err != nil {
fmt.Println(fmt.Sprintf("error from chat request api: %s", err.Error()))
}

CollectQuota(c, user, buffer.GetQuota(), reversible)
c.JSON(http.StatusOK, TranshipmentResponse{
Id: id,
Object: "chat.completion",
Created: created,
Model: form.Model,
Choices: []Choice{
{
Index: 0,
Message: globals.Message{Role: "assistant", Content: buffer.ReadWithDefault(defaultMessage)},
},
},
Usage: Usage{
PromptTokens: int(buffer.CountInputToken()),
CompletionTokens: int(buffer.CountOutputToken()),
TotalTokens: int(buffer.CountToken()),
},
Quota: buffer.GetQuota(),
})
}

func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, data string, buffer *utils.Buffer, end bool) TranshipmentStreamResponse {
return TranshipmentStreamResponse{
Id: id,
Object: "chat.completion",
Created: created,
Model: form.Model,
Choices: []ChoiceDelta{
{
Index: 0,
Delta: globals.Message{
Role: "assistant",
Content: data,
},
FinishReason: utils.Multi[interface{}](end, "stop", nil),
},
},
Usage: Usage{
PromptTokens: int(buffer.CountInputToken()),
CompletionTokens: int(buffer.CountOutputToken()),
TotalTokens: int(buffer.CountToken()),
},
Quota: buffer.GetQuota(),
}
}

func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, created int64, user *auth.User, reversible bool) {
channel := make(chan TranshipmentStreamResponse)

go func() {
buffer := utils.NewBuffer(form.Model, form.Messages)
if err := adapter.NewChatRequest(&adapter.ChatProps{
Model: form.Model,
Message: form.Messages,
Reversible: reversible && globals.IsGPT4Model(form.Model),
Token: form.MaxTokens,
}, func(data string) error {
channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false)
return nil
}); err != nil {
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible)
return
}

channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible)
return
}()

c.Stream(func(w io.Writer) bool {
if resp, ok := <-channel; ok {
c.SSEvent("message", resp)
return true
}
return false
})
}
9 changes: 7 additions & 2 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"chat/auth"
"chat/utils"
"github.com/gin-gonic/gin"
"net/http"
"strings"
)

Expand All @@ -23,13 +24,13 @@ func ProcessKey(c *gin.Context, key string) bool {
cache := utils.GetCacheFromContext(c)

if utils.IsInBlackList(cache, addr) {
c.AbortWithStatusJSON(200, gin.H{
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"code": 403,
"message": "ip in black list",
})
return false
}

if user := auth.ParseApiKey(c, key); user != nil {
c.Set("auth", true)
c.Set("user", user.Username)
Expand All @@ -39,6 +40,10 @@ func ProcessKey(c *gin.Context, key string) bool {
}

utils.IncrIP(cache, addr)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "Access denied. Please provide correct api key.",
})
return false
}

Expand Down
38 changes: 28 additions & 10 deletions utils/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@ import (
)

type Buffer struct {
Model string `json:"model"`
Quota float32 `json:"quota"`
Data string `json:"data"`
Cursor int `json:"cursor"`
Times int `json:"times"`
Model string `json:"model"`
Quota float32 `json:"quota"`
Data string `json:"data"`
Cursor int `json:"cursor"`
Times int `json:"times"`
History []globals.Message `json:"history"`
}

func NewBuffer(model string, history []globals.Message) *Buffer {
return &Buffer{
Data: "",
Cursor: 0,
Times: 0,
Model: model,
Quota: CountInputToken(model, history),
Data: "",
Cursor: 0,
Times: 0,
Model: model,
Quota: CountInputToken(model, history),
History: history,
}
}

Expand Down Expand Up @@ -73,3 +75,19 @@ func (b *Buffer) ReadWithDefault(_default string) string {
func (b *Buffer) ReadTimes() int {
return b.Times
}

func (b *Buffer) ReadHistory() []globals.Message {
return b.History
}

func (b *Buffer) CountInputToken() float32 {
return CountInputToken(b.Model, b.ReadHistory())
}

func (b *Buffer) CountOutputToken() float32 {
return CountOutputToken(b.Model, b.ReadTimes())
}

func (b *Buffer) CountToken() float32 {
return b.CountInputToken() + b.CountOutputToken()
}

0 comments on commit 1bae9b4

Please sign in to comment.