Commit 4d0d92d

feat: update zhipuai chatglm channel api format v3 to v4 and support glm-4 & glm-4v (#147)

Sh1n3zZ committed Mar 28, 2024
1 parent d0bf977 commit 4d0d92d
Showing 10 changed files with 450 additions and 101 deletions.
5 changes: 0 additions & 5 deletions adapter/azure/struct.go

@@ -11,11 +11,6 @@ type ChatInstance struct {
     Resource string
 }

-type InstanceProps struct {
-    Model string
-    Plan  bool
-}
-
 func (c *ChatInstance) GetEndpoint() string {
     return c.Endpoint
 }
5 changes: 0 additions & 5 deletions adapter/openai/struct.go

@@ -11,11 +11,6 @@ type ChatInstance struct {
     ApiKey string
 }

-type InstanceProps struct {
-    Model string
-    Plan  bool
-}
-
 func (c *ChatInstance) GetEndpoint() string {
     return c.Endpoint
 }
157 changes: 115 additions & 42 deletions adapter/zhipuai/chat.go

@@ -4,71 +4,144 @@ import (
     adaptercommon "chat/adapter/common"
     "chat/globals"
     "chat/utils"
+    "errors"
     "fmt"
-    "strings"
+    "regexp"
 )

-func (c *ChatInstance) GetChatEndpoint(model string) string {
-    return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", c.GetEndpoint(), c.GetModel(model))
+func (c *ChatInstance) GetChatEndpoint() string {
+    return fmt.Sprintf("%s/api/paas/v4/chat/completions", c.GetEndpoint())
 }
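The v4 endpoint drops the per-model path and the /sse-invoke suffix in favor of a single OpenAI-style chat-completions route. For reference, a minimal standalone sketch of the wire format this sends (the stand-in structs and JSON tags below are assumptions based on the ChatRequest fields in this diff, not the repository's actual type definitions):

package main

import (
    "encoding/json"
    "fmt"
)

// stand-ins for the adapter's message and request types (assumed shapes)
type sketchMessage struct {
    Role    string `json:"role"`
    Content string `json:"content"`
}

type sketchRequest struct {
    Model    string          `json:"model"`
    Messages []sketchMessage `json:"messages"`
    Stream   bool            `json:"stream"`
}

func main() {
    body, _ := json.Marshal(sketchRequest{
        Model:    "glm-4",
        Messages: []sketchMessage{{Role: "user", Content: "hello"}},
        Stream:   true,
    })
    // POSTed to {endpoint}/api/paas/v4/chat/completions
    fmt.Println(string(body))
}

With stream set to true, the endpoint answers with data: chunks, which the new EventScanner path later in this file consumes.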

-func (c *ChatInstance) GetModel(model string) string {
+func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string {
+    result := ""
+    for _, message := range messages {
+        result += fmt.Sprintf("%s: %s\n", message.Role, message.Content)
+    }
+    return result
+}
+
+func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string {
+    if len(props.Message) == 0 {
+        return ""
+    }
+
+    return props.Message[len(props.Message)-1].Content
+}
+
+func (c *ChatInstance) ConvertModel(model string) string {
+    // for v3 legacy adapter
     switch model {
     case globals.ZhiPuChatGLMTurbo:
-        return ChatGLMTurbo
+        return GLMTurbo
     case globals.ZhiPuChatGLMPro:
-        return ChatGLMPro
+        return GLMPro
     case globals.ZhiPuChatGLMStd:
-        return ChatGLMStd
+        return GLMStd
     case globals.ZhiPuChatGLMLite:
-        return ChatGLMLite
+        return GLMLite
     default:
-        return ChatGLMStd
+        return GLMStd
     }
 }

-func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message {
-    messages = utils.DeepCopy[[]globals.Message](messages)
-    for i := range messages {
-        if messages[i].Role == globals.Tool {
-            continue
-        }
-
-        if messages[i].Role == globals.System {
-            messages[i].Role = globals.User
-        }
-    }
-    return messages
-}
-
-func (c *ChatInstance) GetBody(props *adaptercommon.ChatProps) ChatRequest {
-    return ChatRequest{
-        Prompt:      c.FormatMessages(props.Message),
-        TopP:        props.TopP,
-        Temperature: props.Temperature,
-    }
-}
+func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} {
+    if props.Model == globals.GPT3TurboInstruct {
+        // for completions
+        return CompletionRequest{
+            Model:    c.ConvertModel(props.Model),
+            Prompt:   c.GetCompletionPrompt(props.Message),
+            MaxToken: props.MaxTokens,
+            Stream:   stream,
+        }
+    }
+
+    messages := formatMessages(props)
+
+    // chatglm top_p should be (0.0, 1.0) and cannot be 0 or 1
+    if props.TopP != nil && *props.TopP >= 1.0 {
+        props.TopP = utils.ToPtr[float32](0.99)
+    } else if props.TopP != nil && *props.TopP <= 0.0 {
+        props.TopP = utils.ToPtr[float32](0.01)
+    }
+
+    return ChatRequest{
+        Model:            props.Model,
+        Messages:         messages,
+        MaxToken:         props.MaxTokens,
+        Stream:           stream,
+        PresencePenalty:  props.PresencePenalty,
+        FrequencyPenalty: props.FrequencyPenalty,
+        Temperature:      props.Temperature,
+        TopP:             props.TopP,
+        Tools:            props.Tools,
+        ToolChoice:       props.ToolChoice,
+    }
+}
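Because chatglm rejects top_p values at the boundaries, GetChatBody nudges 0 and 1 inward instead of failing the request. A quick standalone sketch of that clamping behavior (clampTopP is a hypothetical helper; the committed code does this inline on props.TopP):

package main

import "fmt"

// clampTopP mirrors the inline clamping in GetChatBody: values outside
// the open interval (0.0, 1.0) are pulled in to 0.01 / 0.99.
func clampTopP(v float32) float32 {
    if v >= 1.0 {
        return 0.99
    }
    if v <= 0.0 {
        return 0.01
    }
    return v
}

func main() {
    for _, v := range []float32{0.0, 0.5, 1.0} {
        fmt.Printf("top_p %.2f -> %.2f\n", v, clampTopP(v)) // 0.01, 0.50, 0.99
    }
}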

-func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error {
-    return utils.EventSource(
-        "POST",
-        c.GetChatEndpoint(props.Model),
-        map[string]string{
-            "Content-Type":  "application/json",
-            "Accept":        "text/event-stream",
-            "Authorization": c.GetToken(),
-        },
-        ChatRequest{
-            Prompt: c.FormatMessages(props.Message),
-        },
-        func(data string) error {
-            if !strings.HasPrefix(data, "data:") {
-                return nil
-            }
-
-            data = strings.TrimPrefix(data, "data:")
-            return hook(&globals.Chunk{Content: data})
-        },
-        props.Proxy,
-    )
-}
+// CreateChatRequest is the native http request body for chatglm
+func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) {
+    res, err := utils.Post(
+        c.GetChatEndpoint(),
+        c.GetHeader(),
+        c.GetChatBody(props, false),
+        props.Proxy,
+    )
+
+    if err != nil || res == nil {
+        return "", fmt.Errorf("chatglm error: %s", err.Error())
+    }
+
+    data := utils.MapToStruct[ChatResponse](res)
+    if data == nil {
+        return "", fmt.Errorf("chatglm error: cannot parse response")
+    } else if data.Error.Message != "" {
+        return "", fmt.Errorf("chatglm error: %s", data.Error.Message)
+    }
+    return data.Choices[0].Message.Content, nil
+}
+
+func hideRequestId(message string) string {
+    // xxx (request id: 2024020311120561344953f0xfh0TX)
+    exp := regexp.MustCompile(`\(request id: [a-zA-Z0-9]+\)`)
+    return exp.ReplaceAllString(message, "")
+}
+
+// CreateStreamChatRequest is the stream response body for chatglm
+func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error {
+    ticks := 0
+    err := utils.EventScanner(&utils.EventScannerProps{
+        Method:  "POST",
+        Uri:     c.GetChatEndpoint(),
+        Headers: c.GetHeader(),
+        Body:    c.GetChatBody(props, true),
+        Callback: func(data string) error {
+            ticks += 1
+
+            partial, err := c.ProcessLine(data, false)
+            if err != nil {
+                return err
+            }
+            return callback(partial)
+        },
+    }, props.Proxy)
+
+    if err != nil {
+        if form := processChatErrorResponse(err.Body); form != nil {
+            if form.Error.Type == "" && form.Error.Message == "" {
+                return errors.New(utils.ToMarkdownCode("json", err.Body))
+            }
+
+            msg := fmt.Sprintf("%s (code: %s)", form.Error.Message, form.Error.Code)
+            return errors.New(hideRequestId(msg))
+        }
+        return err.Error
+    }
+
+    if ticks == 0 {
+        return errors.New("no response")
+    }
+
+    return nil
+}
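hideRequestId keeps upstream request identifiers out of user-facing error messages. A standalone check of the regex, using the sample format from the code comment (the error text itself is illustrative):

package main

import (
    "fmt"
    "regexp"
)

func main() {
    // same pattern as hideRequestId in this commit
    exp := regexp.MustCompile(`\(request id: [a-zA-Z0-9]+\)`)
    msg := "insufficient balance (request id: 2024020311120561344953f0xfh0TX)"
    // prints "insufficient balance " (the suffix is removed, a trailing space remains)
    fmt.Println(exp.ReplaceAllString(msg, ""))
}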
139 changes: 139 additions & 0 deletions adapter/zhipuai/processor.go
@@ -0,0 +1,139 @@
package zhipuai

import (
    adaptercommon "chat/adapter/common"
    "chat/globals"
    "chat/utils"
    "errors"
    "fmt"
    "regexp"
    "strings"
)

func formatMessages(props *adaptercommon.ChatProps) interface{} {
    if globals.IsVisionModel(props.Model) {
        return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
            if message.Role == globals.User {
                content, urls := utils.ExtractImages(message.Content, true)
                images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
                    obj, err := utils.NewImage(url)
                    props.Buffer.AddImage(obj)
                    if err != nil {
                        globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "...")))
                    }

                    if strings.HasPrefix(url, "data:image/") {
                        // remove base64 image prefix
                        if idx := strings.Index(url, "base64,"); idx != -1 {
                            url = url[idx+7:]
                        }
                    }

                    return &MessageContent{
                        Type: "image_url",
                        ImageUrl: &ImageUrl{
                            Url: url,
                        },
                    }
                })

                return Message{
                    Role: message.Role,
                    Content: utils.Prepend(images, MessageContent{
                        Type: "text",
                        Text: &content,
                    }),
                    Name:         message.Name,
                    FunctionCall: message.FunctionCall,
                    ToolCalls:    message.ToolCalls,
                    ToolCallId:   message.ToolCallId,
                }
            }

            return Message{
                Role:         message.Role,
                Content:      message.Content,
                Name:         message.Name,
                FunctionCall: message.FunctionCall,
                ToolCalls:    message.ToolCalls,
                ToolCallId:   message.ToolCallId,
            }
        })
    }

    return props.Message
}
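For vision models such as glm-4v, formatMessages rewrites a user message into a content array that mixes one text part with any extracted image parts, stripping the base64 prefix from data URLs. A standalone sketch of the resulting JSON shape (the stand-in types below approximate Message/MessageContent/ImageUrl, whose real definitions live in the adapter's types file, not shown in this diff):

package main

import (
    "encoding/json"
    "fmt"
)

// assumed shapes for the adapter's ImageUrl / MessageContent / Message types
type imageUrl struct {
    Url string `json:"url"`
}

type messageContent struct {
    Type     string    `json:"type"`
    Text     *string   `json:"text,omitempty"`
    ImageUrl *imageUrl `json:"image_url,omitempty"`
}

type message struct {
    Role    string      `json:"role"`
    Content interface{} `json:"content"`
}

func main() {
    text := "describe this image"
    m := message{
        Role: "user",
        // text part first, then image parts, matching the utils.Prepend call above
        Content: []messageContent{
            {Type: "text", Text: &text},
            {Type: "image_url", ImageUrl: &imageUrl{Url: "https://example.com/cat.png"}},
        },
    }
    out, _ := json.MarshalIndent(m, "", "  ")
    fmt.Println(string(out))
}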

func processChatResponse(data string) *ChatStreamResponse {
    return utils.UnmarshalForm[ChatStreamResponse](data)
}

func processCompletionResponse(data string) *CompletionResponse {
    return utils.UnmarshalForm[CompletionResponse](data)
}

func processChatErrorResponse(data string) *ChatStreamErrorResponse {
    return utils.UnmarshalForm[ChatStreamErrorResponse](data)
}

func getChoices(form *ChatStreamResponse) *globals.Chunk {
    if len(form.Choices) == 0 {
        return &globals.Chunk{Content: ""}
    }

    choice := form.Choices[0].Delta

    return &globals.Chunk{
        Content:      choice.Content,
        ToolCall:     choice.ToolCalls,
        FunctionCall: choice.FunctionCall,
    }
}

func getCompletionChoices(form *CompletionResponse) string {
    if len(form.Choices) == 0 {
        return ""
    }

    return form.Choices[0].Text
}
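getChoices guards against empty choice lists before reading the first delta. For context, a sketch of how a v4 stream chunk deserializes (stand-in types; the actual ChatStreamResponse definition is not part of this diff):

package main

import (
    "encoding/json"
    "fmt"
)

// assumed minimal shape of a v4 stream chunk
type delta struct {
    Content string `json:"content"`
}

type choice struct {
    Delta delta `json:"delta"`
}

type streamResponse struct {
    Choices []choice `json:"choices"`
}

func main() {
    data := `{"choices":[{"delta":{"content":"Hello"}}]}`
    var form streamResponse
    if err := json.Unmarshal([]byte(data), &form); err == nil && len(form.Choices) > 0 {
        fmt.Println(form.Choices[0].Delta.Content) // prints "Hello"
    }
}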

func getRobustnessResult(chunk string) string {
    exp := `\"content\":\"(.*?)\"`
    compile, err := regexp.Compile(exp)
    if err != nil {
        return ""
    }

    matches := compile.FindStringSubmatch(chunk)
    if len(matches) > 1 {
        return utils.ProcessRobustnessChar(matches[1])
    } else {
        return ""
    }
}
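getRobustnessResult is a salvage path: when a chunk fails JSON parsing, it pulls the raw content field out with a regex. A standalone illustration with a deliberately truncated chunk (the payload is made up):

package main

import (
    "fmt"
    "regexp"
)

func main() {
    // same expression getRobustnessResult compiles
    exp := regexp.MustCompile(`\"content\":\"(.*?)\"`)
    // a truncated, unparseable chunk
    chunk := `{"choices":[{"delta":{"content":"hello world"}}`
    if m := exp.FindStringSubmatch(chunk); len(m) > 1 {
        fmt.Println(m[1]) // prints "hello world"
    }
}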

func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals.Chunk, error) {
    if isCompletionType {
        // chatglm legacy support
        if completion := processCompletionResponse(data); completion != nil {
            return &globals.Chunk{
                Content: getCompletionChoices(completion),
            }, nil
        }

        globals.Warn(fmt.Sprintf("chatglm error: cannot parse completion response: %s", data))
        return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response")
    }

    if form := processChatResponse(data); form != nil {
        return getChoices(form), nil
    }

    if form := processChatErrorResponse(data); form != nil {
        return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatglm error: %s (type: %s)", form.Error.Message, form.Error.Type))
    }

    globals.Warn(fmt.Sprintf("chatglm error: cannot parse chat completion response: %s", data))
    return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response")
}