package zhipu_4v import ( "bufio" "bytes" "encoding/json" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt" "io" "net/http" "one-api/common" "one-api/dto" "one-api/service" "strings" "sync" "time" ) // https://open.bigmodel.cn/doc/api#chatglm_std // chatglm_std, chatglm_lite // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke var zhipuTokens sync.Map var expSeconds int64 = 24 * 3600 func getZhipuToken(apikey string) string { data, ok := zhipuTokens.Load(apikey) if ok { tokenData := data.(tokenData) if time.Now().Before(tokenData.ExpiryTime) { return tokenData.Token } } split := strings.Split(apikey, ".") if len(split) != 2 { common.SysError("invalid zhipu key: " + apikey) return "" } id := split[0] secret := split[1] expMillis := time.Now().Add(time.Duration(expSeconds)*time.Second).UnixNano() / 1e6 expiryTime := time.Now().Add(time.Duration(expSeconds) * time.Second) timestamp := time.Now().UnixNano() / 1e6 payload := jwt.MapClaims{ "api_key": id, "exp": expMillis, "timestamp": timestamp, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) token.Header["alg"] = "HS256" token.Header["sign_type"] = "SIGN" tokenString, err := token.SignedString([]byte(secret)) if err != nil { return "" } zhipuTokens.Store(apikey, tokenData{ Token: tokenString, ExpiryTime: expiryTime, }) return tokenString } func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { messages := make([]dto.Message, 0, len(request.Messages)) for _, message := range request.Messages { messages = append(messages, dto.Message{ Role: message.Role, Content: message.Content, ToolCalls: message.ToolCalls, ToolCallId: message.ToolCallId, }) } str, ok := request.Stop.(string) var Stop []string if ok { Stop = []string{str} } else { Stop, _ = request.Stop.([]string) } return &dto.GeneralOpenAIRequest{ Model: request.Model, Stream: request.Stream, Messages: messages, Temperature: request.Temperature, TopP: request.TopP, MaxTokens: request.MaxTokens, Stop: Stop, Tools: request.Tools, ToolChoice: request.ToolChoice, } } //func responseZhipu2OpenAI(response *dto.OpenAITextResponse) *dto.OpenAITextResponse { // fullTextResponse := dto.OpenAITextResponse{ // Id: response.Id, // Object: "chat.completion", // Created: common.GetTimestamp(), // Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.TextResponseChoices)), // Usage: response.Usage, // } // for i, choice := range response.TextResponseChoices { // content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) // openaiChoice := dto.OpenAITextResponseChoice{ // Index: i, // Message: dto.Message{ // Role: choice.Role, // Content: content, // }, // FinishReason: "", // } // if i == len(response.TextResponseChoices)-1 { // openaiChoice.FinishReason = "stop" // } // fullTextResponse.Choices = append(fullTextResponse.Choices, openaiChoice) // } // return &fullTextResponse //} func streamResponseZhipu2OpenAI(zhipuResponse *ZhipuV4StreamResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.Content = zhipuResponse.Choices[0].Delta.Content choice.Delta.Role = zhipuResponse.Choices[0].Delta.Role choice.Delta.ToolCalls = zhipuResponse.Choices[0].Delta.ToolCalls choice.Index = zhipuResponse.Choices[0].Index choice.FinishReason = zhipuResponse.Choices[0].FinishReason response := dto.ChatCompletionsStreamResponse{ Id: zhipuResponse.Id, Object: "chat.completion.chunk", Created: zhipuResponse.Created, Model: "glm-4", Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } return &response } func lastStreamResponseZhipuV42OpenAI(zhipuResponse *ZhipuV4StreamResponse) (*dto.ChatCompletionsStreamResponse, *dto.Usage) { response := streamResponseZhipu2OpenAI(zhipuResponse) return response, &zhipuResponse.Usage } func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } if i := strings.Index(string(data), "\n"); i >= 0 { return i + 1, data[0:i], nil } if atEOF { return len(data), data, nil } return 0, nil, nil }) dataChan := make(chan string) stopChan := make(chan bool) go func() { for scanner.Scan() { data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format continue } if data[:6] != "data: " && data[:6] != "[DONE]" { continue } dataChan <- data } stopChan <- true }() service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: if strings.HasPrefix(data, "data: [DONE]") { data = data[:12] } // some implementations may add \r at the end of data data = strings.TrimSuffix(data, "\r") var streamResponse ZhipuV4StreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) } var response *dto.ChatCompletionsStreamResponse if strings.Contains(data, "prompt_tokens") { response, usage = lastStreamResponseZhipuV42OpenAI(&streamResponse) } else { response = streamResponseZhipu2OpenAI(&streamResponse) } jsonResponse, err := json.Marshal(response) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: return false } }) err := resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, usage } func zhipuHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var textResponse ZhipuV4Response responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } err = json.Unmarshal(responseBody, &textResponse) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } if textResponse.Error.Type != "" { return &dto.OpenAIErrorWithStatusCode{ Error: textResponse.Error, StatusCode: resp.StatusCode, }, nil } // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) // We shouldn't set the header before we parse the response body, because the parse part may fail. // And then we will have to send an error response, but in this case, the header has already been set. // So the HTTPClient will be confused by the response. // For example, Postman will report error, and we cannot check the response at all. for k, v := range resp.Header { c.Writer.Header().Set(k, v[0]) } c.Writer.WriteHeader(resp.StatusCode) _, err = io.Copy(c.Writer, resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } return nil, &textResponse.Usage }