feat: support zhipu's ChatGLM (close #289)

This commit is contained in:
JustSong
2023-07-23 11:51:44 +08:00
parent c87e05bfc2
commit 26c6719ea3
9 changed files with 403 additions and 49 deletions

View File

@@ -19,6 +19,7 @@ const (
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
)
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
@@ -84,6 +85,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiType = APITypeBaidu
} else if strings.HasPrefix(textRequest.Model, "PaLM") {
apiType = APITypePaLM
} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
apiType = APITypeZhipu
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
@@ -134,6 +137,12 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
}
var promptTokens int
var completionTokens int
@@ -200,6 +209,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
@@ -221,6 +237,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
@@ -252,11 +271,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream && apiType != APITypeBaidu {
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else {
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
if apiType == APITypeZhipu {
// zhipu's API does not return prompt tokens & completion tokens
promptTokens = textResponse.Usage.TotalTokens
}
}
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
@@ -302,7 +325,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage = *usage
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
@@ -318,7 +343,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage = *usage
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
@@ -327,14 +354,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage = *usage
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := baiduHandler(c, resp)
if err != nil {
return err
}
textResponse.Usage = *usage
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
@@ -350,7 +381,29 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if err != nil {
return err
}
textResponse.Usage = *usage
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default: