Merge branch 'songquanpeng' into sync_upstream
@@ -39,6 +39,7 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
 		ID:      aliResponse.RequestId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
+		Model:   aliResponse.Model,
 		Choices: []types.ChatCompletionChoice{choice},
 		Usage: &types.Usage{
 			PromptTokens: aliResponse.Usage.InputTokens,
@@ -50,6 +51,8 @@ func (aliResponse *AliChatResponse) ResponseHandler(resp *http.Response) (OpenAI
 	return
 }

+const AliEnableSearchModelSuffix = "-internet"
+
 // Get the chat request body
 func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *AliChatRequest {
 	messages := make([]AliMessage, 0, len(request.Messages))
@@ -60,11 +63,23 @@ func (p *AliProvider) getChatRequestBody(request *types.ChatCompletionRequest) *
 			Role:    strings.ToLower(message.Role),
 		})
 	}

+	enableSearch := false
+	aliModel := request.Model
+	if strings.HasSuffix(aliModel, AliEnableSearchModelSuffix) {
+		enableSearch = true
+		aliModel = strings.TrimSuffix(aliModel, AliEnableSearchModelSuffix)
+	}
+
 	return &AliChatRequest{
-		Model: request.Model,
+		Model: aliModel,
 		Input: AliInput{
 			Messages: messages,
 		},
+		Parameters: AliParameters{
+			EnableSearch:      enableSearch,
+			IncrementalOutput: request.Stream,
+		},
 	}
 }
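The "-internet" suffix is a routing convention: requesting, say, a model named "qwen-max-internet" turns on Ali's web-search parameter while the upstream call still uses the bare model name. A minimal standalone sketch of the same suffix handling (names here are illustrative, not the project's API):

package main

import (
	"fmt"
	"strings"
)

const enableSearchSuffix = "-internet" // mirrors AliEnableSearchModelSuffix

// splitSearchSuffix reports whether search should be enabled and
// returns the model name to send upstream.
func splitSearchSuffix(model string) (upstreamModel string, enableSearch bool) {
	if strings.HasSuffix(model, enableSearchSuffix) {
		return strings.TrimSuffix(model, enableSearchSuffix), true
	}
	return model, false
}

func main() {
	m, search := splitSearchSuffix("qwen-max-internet")
	fmt.Println(m, search) // qwen-max true
}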
@@ -86,7 +101,7 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
 	}

 	if request.Stream {
-		usage, errWithCode = p.sendStreamRequest(req)
+		usage, errWithCode = p.sendStreamRequest(req, request.Model)
 		if errWithCode != nil {
 			return
 		}
@@ -100,7 +115,9 @@ func (p *AliProvider) ChatAction(request *types.ChatCompletionRequest, isModelMa
 		}

 	} else {
-		aliResponse := &AliChatResponse{}
+		aliResponse := &AliChatResponse{
+			Model: request.Model,
+		}
 		errWithCode = p.SendRequest(req, aliResponse, false)
 		if errWithCode != nil {
 			return
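Seeding the response struct with request.Model before decoding works because json.Unmarshal only touches fields present in the payload: Ali's upstream JSON carries no "model" key, so the pre-set value survives and is later echoed in the OpenAI-shaped response. A hedged sketch of that property, with a trimmed stand-in struct:

package main

import (
	"encoding/json"
	"fmt"
)

// A stand-in for AliChatResponse: the upstream JSON has no "model"
// key, so a pre-seeded value is left untouched by decoding.
type chatResponse struct {
	RequestId string `json:"request_id"`
	Model     string `json:"model,omitempty"`
}

func main() {
	resp := chatResponse{Model: "qwen-max"} // seeded before decoding
	payload := []byte(`{"request_id":"abc123"}`)
	if err := json.Unmarshal(payload, &resp); err != nil {
		panic(err)
	}
	fmt.Println(resp.Model) // qwen-max — not overwritten by Unmarshal
}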
@@ -128,14 +145,14 @@ func (p *AliProvider) streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ty
 		ID:      aliResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "ernie-bot",
+		Model:   aliResponse.Model,
 		Choices: []types.ChatCompletionStreamChoice{choice},
 	}
 	return &response
 }

 // Send the stream request
-func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
+func (p *AliProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 	defer req.Body.Close()

 	usage = &types.Usage{}
@@ -181,7 +198,7 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
 		stopChan <- true
 	}()
 	common.SetEventStreamHeaders(p.Context)
-	lastResponseText := ""
+	// lastResponseText := ""
 	p.Context.Stream(func(w io.Writer) bool {
 		select {
 		case data := <-dataChan:
@@ -196,9 +213,10 @@ func (p *AliProvider) sendStreamRequest(req *http.Request) (usage *types.Usage,
 				usage.CompletionTokens = aliResponse.Usage.OutputTokens
 				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 			}
+			aliResponse.Model = model
 			response := p.streamResponseAli2OpenAI(&aliResponse)
-			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
-			lastResponseText = aliResponse.Output.Text
+			// response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
+			// lastResponseText = aliResponse.Output.Text
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
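The TrimPrefix bookkeeping is retired because the request now sets incremental_output: true for streams: Ali then emits only the new tokens per event instead of the full text so far, so there is no accumulated prefix left to strip. A hedged sketch of the difference (payloads are illustrative):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Cumulative mode: each event repeats everything so far, so the
	// client must strip what it has already forwarded.
	cumulative := []string{"Hel", "Hello, wor", "Hello, world!"}
	last := ""
	for _, text := range cumulative {
		delta := strings.TrimPrefix(text, last)
		fmt.Printf("%q ", delta)
		last = text
	}
	fmt.Println()

	// Incremental mode (incremental_output: true): each event is
	// already a delta and can be forwarded as-is.
	incremental := []string{"Hel", "lo, wor", "ld!"}
	for _, delta := range incremental {
		fmt.Printf("%q ", delta)
	}
	fmt.Println()
}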
@@ -23,10 +23,11 @@ type AliInput struct {
 }

 type AliParameters struct {
-	TopP         float64 `json:"top_p,omitempty"`
-	TopK         int     `json:"top_k,omitempty"`
-	Seed         uint64  `json:"seed,omitempty"`
-	EnableSearch bool    `json:"enable_search,omitempty"`
+	TopP              float64 `json:"top_p,omitempty"`
+	TopK              int     `json:"top_k,omitempty"`
+	Seed              uint64  `json:"seed,omitempty"`
+	EnableSearch      bool    `json:"enable_search,omitempty"`
+	IncrementalOutput bool    `json:"incremental_output,omitempty"`
 }

 type AliChatRequest struct {
@@ -43,6 +44,7 @@ type AliOutput struct {
 type AliChatResponse struct {
 	Output AliOutput `json:"output"`
 	Usage  AliUsage  `json:"usage"`
+	Model  string    `json:"model,omitempty"`
 	AliError
 }
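The new Model fields all carry `json:"model,omitempty"`, so this internal bookkeeping never leaks onto the wire: the field is populated locally around decode time, and encoding/json skips zero values on marshal. A quick illustration with stand-in types:

package main

import (
	"encoding/json"
	"fmt"
)

type params struct {
	EnableSearch      bool   `json:"enable_search,omitempty"`
	IncrementalOutput bool   `json:"incremental_output,omitempty"`
	Model             string `json:"model,omitempty"`
}

func main() {
	// Zero values are omitted entirely, keeping the request body minimal.
	a, _ := json.Marshal(params{})
	fmt.Println(string(a)) // {}

	b, _ := json.Marshal(params{EnableSearch: true, IncrementalOutput: true})
	fmt.Println(string(b)) // {"enable_search":true,"incremental_output":true}
}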
@@ -88,13 +88,15 @@ func (p *BaiduProvider) ChatAction(request *types.ChatCompletionRequest, isModel
 	}

 	if request.Stream {
-		usage, errWithCode = p.sendStreamRequest(req)
+		usage, errWithCode = p.sendStreamRequest(req, request.Model)
 		if errWithCode != nil {
 			return
 		}

 	} else {
-		baiduChatRequest := &BaiduChatResponse{}
+		baiduChatRequest := &BaiduChatResponse{
+			Model: request.Model,
+		}
 		errWithCode = p.SendRequest(req, baiduChatRequest, false)
 		if errWithCode != nil {
 			return
@@ -117,13 +119,13 @@ func (p *BaiduProvider) streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStrea
 		ID:      baiduResponse.Id,
 		Object:  "chat.completion.chunk",
 		Created: baiduResponse.Created,
-		Model:   "ernie-bot",
+		Model:   baiduResponse.Model,
 		Choices: []types.ChatCompletionStreamChoice{choice},
 	}
 	return &response
 }

-func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
+func (p *BaiduProvider) sendStreamRequest(req *http.Request, model string) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
 	defer req.Body.Close()

 	usage = &types.Usage{}
@@ -180,6 +182,7 @@ func (p *BaiduProvider) sendStreamRequest(req *http.Request) (usage *types.Usage
 			usage.PromptTokens = baiduResponse.Usage.PromptTokens
 			usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 		}
+		baiduResponse.Model = model
 		response := p.streamResponseBaidu2OpenAI(&baiduResponse)
 		jsonResponse, err := json.Marshal(response)
 		if err != nil {

@@ -32,6 +32,7 @@ type BaiduChatResponse struct {
 	IsTruncated      bool         `json:"is_truncated"`
 	NeedClearHistory bool         `json:"need_clear_history"`
 	Usage            *types.Usage `json:"usage"`
+	Model            string       `json:"model,omitempty"`
 	BaiduError
 }
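Baidu's streamed usage block reports prompt and total tokens, so completion tokens are derived by subtraction, as in this tiny worked example with made-up numbers:

package main

import "fmt"

func main() {
	// Baidu's stream usage reports prompt and total tokens only;
	// completion tokens are the difference (numbers illustrative).
	prompt, total := 45, 130
	fmt.Println(total - prompt) // 85 completion tokens
}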
@@ -38,6 +38,7 @@ func (claudeResponse *ClaudeResponse) ResponseHandler(resp *http.Response) (Open
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Choices: []types.ChatCompletionChoice{choice},
+		Model:   claudeResponse.Model,
 	}

 	completionTokens := common.CountTokenText(claudeResponse.Completion, claudeResponse.Model)
@@ -32,7 +32,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
 		version = p.Context.GetString("api_version")
 	}

-	return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key"))
+	return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)

 }

@@ -40,6 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
 func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
 	p.CommonRequestHeaders(headers)
+	headers["x-goog-api-key"] = p.Context.GetString("api_key")

 	return headers
 }
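Both the Gemini and PaLM providers move the API key out of the ?key= query string and into the x-goog-api-key header, which Google's Generative Language API accepts and which keeps credentials out of URLs (and therefore out of access logs and proxies). A minimal sketch of the header-based call; the endpoint and key are illustrative:

package main

import (
	"fmt"
	"net/http"
)

func main() {
	// The key travels in a header rather than the query string, so it
	// no longer appears in the request URL.
	req, err := http.NewRequest(http.MethodPost,
		"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent", nil)
	if err != nil {
		panic(err)
	}
	req.Header.Set("x-goog-api-key", "YOUR_API_KEY")
	fmt.Println(req.URL.String()) // no ?key=... in the URL
}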
@@ -7,11 +7,16 @@ import (
 	"io"
 	"net/http"
 	"one-api/common"
+	"one-api/common/image"
 	"one-api/providers/base"
 	"one-api/types"
 	"strings"
 )

+const (
+	GeminiVisionMaxImageNum = 16
+)
+
 func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
 	if len(response.Candidates) == 0 {
 		return nil, &types.OpenAIErrorWithStatusCode{
@@ -29,6 +34,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
 		ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
+		Model:   response.Model,
 		Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
@@ -46,7 +52,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}

-	completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro")
+	completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
 	response.Usage.CompletionTokens = completionTokens
 	response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
@@ -98,6 +104,31 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
 				},
 			},
 		}

+		openaiContent := message.ParseContent()
+		var parts []GeminiPart
+		imageNum := 0
+		for _, part := range openaiContent {
+			if part.Type == types.ContentTypeText {
+				parts = append(parts, GeminiPart{
+					Text: part.Text,
+				})
+			} else if part.Type == types.ContentTypeImageURL {
+				imageNum += 1
+				if imageNum > GeminiVisionMaxImageNum {
+					continue
+				}
+				mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
+				parts = append(parts, GeminiPart{
+					InlineData: &GeminiInlineData{
+						MimeType: mimeType,
+						Data:     data,
+					},
+				})
+			}
+		}
+		content.Parts = parts

 		// there's no assistant role in gemini and API shall vomit if Role is not user or model
 		if content.Role == "assistant" {
 			content.Role = "model"
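This loop flattens an OpenAI-style multimodal message (mixed text and image_url parts) into Gemini parts, silently dropping inline images beyond GeminiVisionMaxImageNum. A standalone sketch of the same conversion shape, with trimmed stand-in types rather than the project's:

package main

import "fmt"

// Trimmed stand-ins for the provider's types (illustrative only).
type openaiPart struct {
	Type string // "text" or "image_url"
	Text string
	URL  string
}

type geminiPart struct {
	Text string
	// InlineData (mime type + base64 payload) is elided in this sketch.
}

const maxImages = 16 // mirrors GeminiVisionMaxImageNum

func toGeminiParts(content []openaiPart) []geminiPart {
	var parts []geminiPart
	images := 0
	for _, p := range content {
		switch p.Type {
		case "text":
			parts = append(parts, geminiPart{Text: p.Text})
		case "image_url":
			images++
			if images > maxImages {
				continue // drop images beyond Gemini's per-request limit
			}
			// The real code fetches the URL and inlines mime type + data here.
			parts = append(parts, geminiPart{Text: "<inline image: " + p.URL + ">"})
		}
	}
	return parts
}

func main() {
	parts := toGeminiParts([]openaiPart{
		{Type: "text", Text: "What is in this picture?"},
		{Type: "image_url", URL: "https://example.com/cat.png"},
	})
	fmt.Println(len(parts)) // 2
}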
@@ -142,7 +173,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode

 	if request.Stream {
 		var responseText string
-		errWithCode, responseText = p.sendStreamRequest(req)
+		errWithCode, responseText = p.sendStreamRequest(req, request.Model)
 		if errWithCode != nil {
 			return
 		}
@@ -155,6 +186,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode

 	} else {
 		var geminiResponse = &GeminiChatResponse{
+			Model: request.Model,
 			Usage: &types.Usage{
 				PromptTokens: promptTokens,
 			},
@@ -170,18 +202,18 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode

 }

-func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
-	var choice types.ChatCompletionStreamChoice
-	choice.Delta.Content = geminiResponse.GetResponseText()
-	choice.FinishReason = &base.StopFinishReason
-	var response types.ChatCompletionStreamResponse
-	response.Object = "chat.completion.chunk"
-	response.Model = "gemini"
-	response.Choices = []types.ChatCompletionStreamChoice{choice}
-	return &response
-}
+// func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
+// 	var choice types.ChatCompletionStreamChoice
+// 	choice.Delta.Content = geminiResponse.GetResponseText()
+// 	choice.FinishReason = &base.StopFinishReason
+// 	var response types.ChatCompletionStreamResponse
+// 	response.Object = "chat.completion.chunk"
+// 	response.Model = "gemini"
+// 	response.Choices = []types.ChatCompletionStreamChoice{choice}
+// 	return &response
+// }

-func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
+func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
 	defer req.Body.Close()

 	// Send the request
@@ -235,7 +267,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
 				Content string `json:"content"`
 			}
 			var dummy dummyStruct
-			err := json.Unmarshal([]byte(data), &dummy)
+			json.Unmarshal([]byte(data), &dummy)
 			responseText += dummy.Content
 			var choice types.ChatCompletionStreamChoice
 			choice.Delta.Content = dummy.Content
@@ -243,7 +275,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
 				ID:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
 				Object:  "chat.completion.chunk",
 				Created: common.GetTimestamp(),
-				Model:   "gemini-pro",
+				Model:   model,
 				Choices: []types.ChatCompletionStreamChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
@@ -46,6 +46,7 @@ type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
 	Usage          *types.Usage             `json:"usage,omitempty"`
+	Model          string                   `json:"model,omitempty"`
 }

 type GeminiChatCandidate struct {
@@ -29,6 +29,7 @@ type PalmProvider struct {
 func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
 	headers = make(map[string]string)
 	p.CommonRequestHeaders(headers)
+	headers["x-goog-api-key"] = p.Context.GetString("api_key")

 	return headers
 }
@@ -37,5 +38,5 @@ func (p *PalmProvider) GetRequestHeaders() (headers map[string]string) {
 func (p *PalmProvider) GetFullRequestURL(requestURL string, modelName string) string {
 	baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")

-	return fmt.Sprintf("%s%s?key=%s", baseURL, requestURL, p.Context.GetString("api_key"))
+	return fmt.Sprintf("%s%s", baseURL, requestURL)
 }

@@ -43,6 +43,7 @@ func (palmResponse *PaLMChatResponse) ResponseHandler(resp *http.Response) (Open
 	palmResponse.Usage.TotalTokens = palmResponse.Usage.PromptTokens + completionTokens

 	fullTextResponse.Usage = palmResponse.Usage
+	fullTextResponse.Model = palmResponse.Model

 	return fullTextResponse, nil
 }
@@ -27,6 +27,7 @@ func (TencentResponse *TencentChatResponse) ResponseHandler(resp *http.Response)
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
 		Usage:   TencentResponse.Usage,
+		Model:   TencentResponse.Model,
 	}
 	if len(TencentResponse.Choices) > 0 {
 		choice := types.ChatCompletionChoice{
@@ -100,7 +101,7 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod

 	if request.Stream {
 		var responseText string
-		errWithCode, responseText = p.sendStreamRequest(req)
+		errWithCode, responseText = p.sendStreamRequest(req, request.Model)
 		if errWithCode != nil {
 			return
 		}
@@ -112,7 +113,9 @@ func (p *TencentProvider) ChatAction(request *types.ChatCompletionRequest, isMod
 		usage.TotalTokens = promptTokens + usage.CompletionTokens

 	} else {
-		tencentResponse := &TencentChatResponse{}
+		tencentResponse := &TencentChatResponse{
+			Model: request.Model,
+		}
 		errWithCode = p.SendRequest(req, tencentResponse, false)
 		if errWithCode != nil {
 			return
@@ -128,7 +131,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
 	response := types.ChatCompletionStreamResponse{
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "tencent-hunyuan",
+		Model:   TencentResponse.Model,
 	}
 	if len(TencentResponse.Choices) > 0 {
 		var choice types.ChatCompletionStreamChoice
@@ -141,7 +144,7 @@ func (p *TencentProvider) streamResponseTencent2OpenAI(TencentResponse *TencentC
 	return &response
 }

-func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
+func (p *TencentProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
 	defer req.Body.Close()
 	// Send the request
 	resp, err := common.HttpClient.Do(req)
@@ -195,6 +198,7 @@ func (p *TencentProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErr
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 			}
+			TencentResponse.Model = model
 			response := p.streamResponseTencent2OpenAI(&TencentResponse)
 			if len(response.Choices) != 0 {
 				responseText += response.Choices[0].Delta.Content

@@ -58,4 +58,5 @@ type TencentChatResponse struct {
 	Error TencentError `json:"error,omitempty"`  // Error message. Note: this field may return null, meaning no valid value could be obtained
 	Note  string       `json:"note,omitempty"`   // Remarks
 	ReqID string       `json:"req_id,omitempty"` // Unique request ID, returned with every request; used as an input parameter for the feedback API
+	Model string       `json:"model,omitempty"`  // Model name
 }
@@ -28,6 +28,7 @@ func (zhipuResponse *ZhipuResponse) ResponseHandler(resp *http.Response) (OpenAI
 		ID:      zhipuResponse.Data.TaskId,
 		Object:  "chat.completion",
 		Created: common.GetTimestamp(),
+		Model:   zhipuResponse.Model,
 		Choices: make([]types.ChatCompletionChoice, 0, len(zhipuResponse.Data.Choices)),
 		Usage:   &zhipuResponse.Data.Usage,
 	}
@@ -94,13 +95,15 @@ func (p *ZhipuProvider) ChatAction(request *types.ChatCompletionRequest, isModel
 	}

 	if request.Stream {
-		errWithCode, usage = p.sendStreamRequest(req)
+		errWithCode, usage = p.sendStreamRequest(req, request.Model)
 		if errWithCode != nil {
 			return
 		}

 	} else {
-		zhipuResponse := &ZhipuResponse{}
+		zhipuResponse := &ZhipuResponse{
+			Model: request.Model,
+		}
 		errWithCode = p.SendRequest(req, zhipuResponse, false)
 		if errWithCode != nil {
 			return
@@ -132,13 +135,13 @@ func (p *ZhipuProvider) streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStrea
 		ID:      zhipuResponse.RequestId,
 		Object:  "chat.completion.chunk",
 		Created: common.GetTimestamp(),
-		Model:   "chatglm",
+		Model:   zhipuResponse.Model,
 		Choices: []types.ChatCompletionStreamChoice{choice},
 	}
 	return &response, &zhipuResponse.Usage
 }

-func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
+func (p *ZhipuProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, *types.Usage) {
 	defer req.Body.Close()

 	// Send the request
@@ -159,7 +162,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
 		if atEOF && len(data) == 0 {
 			return 0, nil, nil
 		}
-		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 {
+		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
 			return i + 2, data[0:i], nil
 		}
 		if atEOF {
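The strings.Contains(s, ":") form is the idiomatic equivalent of strings.Index(s, ":") >= 0 (the gosimple linter flags the latter as S1003); behavior is unchanged. For context, this predicate lives in a bufio.SplitFunc that chunks Zhipu's SSE stream on blank lines, roughly like this standalone sketch (input is illustrative):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	input := "event: add\ndata: hello\n\nevent: finish\ndata: done\n\n"
	scanner := bufio.NewScanner(strings.NewReader(input))
	// Split on blank lines, but only once a ':' has been seen,
	// mirroring the provider's custom SplitFunc.
	scanner.Split(func(data []byte, atEOF bool) (int, []byte, error) {
		if atEOF && len(data) == 0 {
			return 0, nil, nil
		}
		if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Contains(string(data), ":") {
			return i + 2, data[:i], nil
		}
		if atEOF {
			return len(data), data, nil
		}
		return 0, nil, nil // request more data
	})
	for scanner.Scan() {
		fmt.Printf("event block: %q\n", scanner.Text())
	}
}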
@@ -195,6 +198,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
 		select {
 		case data := <-dataChan:
 			response := p.streamResponseZhipu2OpenAI(data)
+			response.Model = model
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {
 				common.SysError("error marshalling stream response: " + err.Error())
@@ -209,6 +213,7 @@ func (p *ZhipuProvider) sendStreamRequest(req *http.Request) (*types.OpenAIError
 				common.SysError("error unmarshalling stream response: " + err.Error())
 				return true
 			}
+			zhipuResponse.Model = model
 			response, zhipuUsage := p.streamMetaResponseZhipu2OpenAI(&zhipuResponse)
 			jsonResponse, err := json.Marshal(response)
 			if err != nil {

@@ -31,6 +31,7 @@ type ZhipuResponse struct {
 	Msg     string            `json:"msg"`
 	Success bool              `json:"success"`
 	Data    ZhipuResponseData `json:"data"`
+	Model   string            `json:"model,omitempty"`
 }

 type ZhipuStreamMetaResponse struct {
@@ -38,6 +39,7 @@ type ZhipuStreamMetaResponse struct {
 	TaskId      string `json:"task_id"`
 	TaskStatus  string `json:"task_status"`
 	types.Usage `json:"usage"`
+	Model       string `json:"model,omitempty"`
 }

 type zhipuTokenData struct {