diff --git a/manager/transhipment.go b/manager/transhipment.go index 018e561b..41f20a92 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -28,6 +28,7 @@ type TranshipmentForm struct { TopK *int `json:"top_k"` Tools *globals.FunctionTools ToolChoice *interface{} + Official bool `json:"official"` } type Choice struct { @@ -49,7 +50,7 @@ type TranshipmentResponse struct { Model string `json:"model"` Choices []Choice `json:"choices"` Usage Usage `json:"usage"` - Quota float32 `json:"quota"` + Quota *float32 `json:"quota,omitempty"` } type ChoiceDelta struct { @@ -65,7 +66,7 @@ type TranshipmentStreamResponse struct { Model string `json:"model"` Choices []ChoiceDelta `json:"choices"` Usage Usage `json:"usage"` - Quota float32 `json:"quota"` + Quota *float32 `json:"quota,omitempty"` } func ModelAPI(c *gin.Context) { @@ -116,6 +117,11 @@ func TranshipmentAPI(c *gin.Context) { } } + if strings.HasSuffix(form.Model, "-official") { + form.Model = strings.TrimSuffix(form.Model, "-official") + form.Official = true + } + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model) if !check { c.JSON(http.StatusForbidden, gin.H{ @@ -185,7 +191,7 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, CompletionTokens: buffer.CountOutputToken(), TotalTokens: buffer.CountToken(), }, - Quota: buffer.GetQuota(), + Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())), }) } @@ -210,7 +216,7 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, CompletionTokens: utils.MultiF(end, func() int { return buffer.CountOutputToken() }, 0), TotalTokens: utils.MultiF(end, func() int { return buffer.CountToken() }, 0), }, - Quota: buffer.GetQuota(), + Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())), } }