feat: 修复智谱GLM-4V流模式异常

This commit is contained in:
1808837298@qq.com 2024-03-01 22:31:08 +08:00
parent 413d4f0a66
commit 84cac72a45
7 changed files with 17 additions and 13 deletions

View File

@ -71,10 +71,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = openaiStreamHandler(c, resp, info.RelayMode) err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -16,7 +16,7 @@ import (
"time" "time"
) )
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@ -111,7 +111,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String() return nil, responseTextBuilder.String()
} }
func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
var textResponse dto.TextResponse var textResponse dto.TextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
package zhipu_v4 package zhipu_4v
import ( import (
"errors" "errors"
@ -8,7 +8,9 @@ import (
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service"
) )
type Adaptor struct { type Adaptor struct {
@ -41,9 +43,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = zhipuStreamHandler(c, resp) var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = zhipuHandler(c, resp) err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -1,7 +1,7 @@
package zhipu_v4 package zhipu_4v
var ModelList = []string{ var ModelList = []string{
"glm-4", "glm-4v", "glm-3-turbo", "glm-4", "glm-4v", "glm-3-turbo",
} }
var ChannelName = "zhipu_v4" var ChannelName = "zhipu_4v"

View File

@ -1,4 +1,4 @@
package zhipu_v4 package zhipu_4v
import ( import (
"one-api/dto" "one-api/dto"

View File

@ -1,4 +1,4 @@
package zhipu_v4 package zhipu_4v
import ( import (
"bufio" "bufio"

View File

@ -11,7 +11,7 @@ import (
"one-api/relay/channel/tencent" "one-api/relay/channel/tencent"
"one-api/relay/channel/xunfei" "one-api/relay/channel/xunfei"
"one-api/relay/channel/zhipu" "one-api/relay/channel/zhipu"
"one-api/relay/channel/zhipu_v4" "one-api/relay/channel/zhipu_4v"
"one-api/relay/constant" "one-api/relay/constant"
) )
@ -38,7 +38,7 @@ func GetAdaptor(apiType int) channel.Adaptor {
case constant.APITypeZhipu: case constant.APITypeZhipu:
return &zhipu.Adaptor{} return &zhipu.Adaptor{}
case constant.APITypeZhipu_v4: case constant.APITypeZhipu_v4:
return &zhipu_v4.Adaptor{} return &zhipu_4v.Adaptor{}
} }
return nil return nil
} }