feat: 支持智谱GLM-4V

This commit is contained in:
1808837298@qq.com
2024-02-29 18:31:03 +08:00
parent a3687b72f8
commit b4645d1019
16 changed files with 483 additions and 578 deletions

View File

@@ -0,0 +1,57 @@
package zhipu_v4
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
)
type Adaptor struct {
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
token := getZhipuToken(info.ApiKey)
req.Header.Set("Authorization", token)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Zhipu(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream {
err, usage = zhipuStreamHandler(c, resp)
} else {
err, usage = zhipuHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package zhipu_v4
var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo",
}
var ChannelName = "zhipu_v4"

View File

@@ -0,0 +1,59 @@
package zhipu_v4
import (
"one-api/dto"
"time"
)
// type ZhipuMessage struct {
// Role string `json:"role,omitempty"`
// Content string `json:"content,omitempty"`
// ToolCalls any `json:"tool_calls,omitempty"`
// ToolCallId any `json:"tool_call_id,omitempty"`
// }
//
// type ZhipuRequest struct {
// Model string `json:"model"`
// Stream bool `json:"stream,omitempty"`
// Messages []ZhipuMessage `json:"messages"`
// Temperature float64 `json:"temperature,omitempty"`
// TopP float64 `json:"top_p,omitempty"`
// MaxTokens int `json:"max_tokens,omitempty"`
// Stop []string `json:"stop,omitempty"`
// RequestId string `json:"request_id,omitempty"`
// Tools any `json:"tools,omitempty"`
// ToolChoice any `json:"tool_choice,omitempty"`
// }
//
// type ZhipuV4TextResponseChoice struct {
// Index int `json:"index"`
// ZhipuMessage `json:"message"`
// FinishReason string `json:"finish_reason"`
// }
type ZhipuV4Response struct {
Id string `json:"id"`
Created int64 `json:"created"`
Model string `json:"model"`
TextResponseChoices []dto.OpenAITextResponseChoice `json:"choices"`
Usage dto.Usage `json:"usage"`
Error dto.OpenAIError `json:"error"`
}
//
//type ZhipuV4StreamResponseChoice struct {
// Index int `json:"index,omitempty"`
// Delta ZhipuMessage `json:"delta"`
// FinishReason *string `json:"finish_reason,omitempty"`
//}
type ZhipuV4StreamResponse struct {
Id string `json:"id"`
Created int64 `json:"created"`
Choices []dto.ChatCompletionsStreamResponseChoice `json:"choices"`
Usage dto.Usage `json:"usage"`
}
type tokenData struct {
Token string
ExpiryTime time.Time
}

View File

@@ -0,0 +1,262 @@
package zhipu_v4
import (
"bufio"
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/service"
"strings"
"sync"
"time"
)
// https://open.bigmodel.cn/doc/api#chatglm_std
// chatglm_std, chatglm_lite
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600
func getZhipuToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey)
if ok {
tokenData := data.(tokenData)
if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token
}
}
split := strings.Split(apikey, ".")
if len(split) != 2 {
common.SysError("invalid zhipu key: " + apikey)
return ""
}
id := split[0]
secret := split[1]
expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6
expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second)
timestamp := time.Now().UnixNano() / 1e6
payload := jwt.MapClaims{
"api_key": id,
"exp": expMillis,
"timestamp": timestamp,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
token.Header["alg"] = "HS256"
token.Header["sign_type"] = "SIGN"
tokenString, err := token.SignedString([]byte(secret))
if err != nil {
return ""
}
zhipuTokens.Store(apikey, tokenData{
Token: tokenString,
ExpiryTime: expiryTime,
})
return tokenString
}
func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
Stop: Stop,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
}
//func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse {
// fullTextResponse := dto.OpenAITextResponse{
// Id: response.Id,
// Object: "chat.completion",
// Created: common.GetTimestamp(),
// Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)),
// Usage: response.Usage,
// }
// for i, choice := range response.TextResponseChoices {
// content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
// openaiChoice := dto.OpenAITextResponseChoice{
// Index: i,
// Message: dto.Message{
// Role: choice.Role,
// Content: content,
// },
// FinishReason: "",
// }
// if i == len(response.TextResponseChoices)-1 {
// openaiChoice.FinishReason = "stop"
// }
// fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice)
// }
// return &fullTextResponse
//}
func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content
choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role
choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls
choice.Index = zhipuResponse.Choices[0].Index
choice.FinishReason = zhipuResponse.Choices[0].FinishReason
response := dto.ChatCompletionsStreamResponse{
Id: zhipuResponse.Id,
Object: "chat.completion.chunk",
Created: zhipuResponse.Created,
Model: "glm-4",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) {
response := streamResponseZhipu2OpenAI(zhipuResponse)
return response, &zhipuResponse.Usage
}
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var usage *dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
}
stopChan <- true
}()
service.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
var streamResponse ZhipuV4StreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
}
var response *dto.ChatCompletionsStreamResponse
if strings.Contains(data, "prompt_tokens") {
response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse)
} else {
response = streamResponseZhipu2OpenAI(&streamResponse)
}
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, usage
}
func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse ZhipuV4Response
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the HTTPClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}