Merge branch 'songquanpeng' into sync_upstream

Author: Martial BE
Date:   2023-12-25 11:23:28 +08:00

24 changed files with 251 additions and 49 deletions

View File

@@ -32,7 +32,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
version = p.Context.GetString("api_version")
}
-     return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, modelName, requestURL, p.Context.GetString("api_key"))
+     return fmt.Sprintf("%s/%s/models/%s:%s", baseURL, version, modelName, requestURL)
}
@@ -40,6 +40,7 @@ func (p *GeminiProvider) GetFullRequestURL(requestURL string, modelName string)
func (p *GeminiProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["x-goog-api-key"] = p.Context.GetString("api_key")
return headers
}
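Both hunks above make a single change: the Gemini API key moves out of the `?key=` query string and into the `x-goog-api-key` request header, so it no longer appears in URLs, access logs, or traces. A minimal, self-contained sketch of the two auth styles; the endpoint and payload here are illustrative, not taken from this repo:

```go
package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	apiKey := "YOUR_KEY" // placeholder; never a real key in examples
	body := bytes.NewBufferString(`{"contents":[{"parts":[{"text":"hi"}]}]}`)

	req, err := http.NewRequest(http.MethodPost,
		"https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent",
		body)
	if err != nil {
		panic(err)
	}

	// Old style, what the removed return built: the key rides in the URL,
	// where proxies and access logs are likely to record it.
	//   q := req.URL.Query(); q.Set("key", apiKey); req.URL.RawQuery = q.Encode()

	// New style, matching GetRequestHeaders: the key travels in a header.
	req.Header.Set("x-goog-api-key", apiKey)

	fmt.Println(req.URL.String())
	fmt.Println(req.Header.Get("x-goog-api-key"))
}
```

Either way the request hits the same endpoint; only where the secret travels changes.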

View File

@@ -7,11 +7,16 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"one-api/providers/base"
"one-api/types"
"strings"
)
+ const (
+     GeminiVisionMaxImageNum = 16
+ )
func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if len(response.Candidates) == 0 {
return nil, &types.OpenAIErrorWithStatusCode{
@@ -29,6 +34,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
+     Model: response.Model,
Choices: make([]types.ChatCompletionChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
@@ -46,7 +52,7 @@ func (response *GeminiChatResponse) ResponseHandler(resp *http.Response) (OpenAI
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
-     completionTokens := common.CountTokenText(response.GetResponseText(), "gemini-pro")
+     completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
response.Usage.CompletionTokens = completionTokens
response.Usage.TotalTokens = response.Usage.PromptTokens + completionTokens
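Together with the new `Model` field on the response, this hunk counts completion tokens against the model the client actually requested instead of the hardcoded `gemini-pro`. A rough sketch of the flow, with a hypothetical stand-in for `common.CountTokenText` (the real helper resolves a tokenizer per model name; the arithmetic below is illustrative only):

```go
package main

import "fmt"

// countTokenText is a stand-in for one-api's common.CountTokenText, which
// resolves a tokenizer for the given model name. The point of the hunk is
// only that the model string now comes from the response, not a constant.
func countTokenText(text, model string) int {
	fmt.Printf("counting %q against tokenizer for %s\n", text, model)
	return (len([]rune(text)) + 3) / 4 // crude ~4-chars-per-token estimate
}

func main() {
	promptTokens := 42 // assume already counted when the request was built
	completion := "a cat sitting on a windowsill"

	completionTokens := countTokenText(completion, "gemini-pro-vision")
	fmt.Println("total tokens:", promptTokens+completionTokens)
}
```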
@@ -98,6 +104,31 @@ func (p *GeminiProvider) getChatRequestBody(request *types.ChatCompletionRequest
},
},
}
+ openaiContent := message.ParseContent()
+ var parts []GeminiPart
+ imageNum := 0
+ for _, part := range openaiContent {
+     if part.Type == types.ContentTypeText {
+         parts = append(parts, GeminiPart{
+             Text: part.Text,
+         })
+     } else if part.Type == types.ContentTypeImageURL {
+         imageNum += 1
+         if imageNum > GeminiVisionMaxImageNum {
+             continue
+         }
+         mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.URL)
+         parts = append(parts, GeminiPart{
+             InlineData: &GeminiInlineData{
+                 MimeType: mimeType,
+                 Data:     data,
+             },
+         })
+     }
+ }
+ content.Parts = parts
// Gemini has no "assistant" role, and the API rejects any role other than "user" or "model"
if content.Role == "assistant" {
content.Role = "model"
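The added block translates OpenAI-style multimodal content into Gemini `parts`: text parts pass through, and up to `GeminiVisionMaxImageNum` (16) images are fetched and inlined as base64 data. A self-contained sketch of the same shape, with stand-in types whose field names mirror the diff and a stub in place of the repo's `image.GetImageFromUrl`:

```go
package main

import "fmt"

// Stand-ins for the types the hunk uses; field names mirror the diff.
type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"` // base64-encoded image bytes
}

type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}

// openaiPart mimics one parsed OpenAI content part ("text" or "image_url").
type openaiPart struct {
	Type string
	Text string
	URL  string
}

const geminiVisionMaxImageNum = 16

// toGeminiParts follows the same shape as the loop in getChatRequestBody:
// text passes through, images past the cap are silently dropped, and the
// rest are fetched and inlined.
func toGeminiParts(content []openaiPart, fetch func(url string) (mime, data string)) []GeminiPart {
	var parts []GeminiPart
	imageNum := 0
	for _, p := range content {
		switch p.Type {
		case "text":
			parts = append(parts, GeminiPart{Text: p.Text})
		case "image_url":
			imageNum++
			if imageNum > geminiVisionMaxImageNum {
				continue
			}
			mime, data := fetch(p.URL)
			parts = append(parts, GeminiPart{InlineData: &GeminiInlineData{MimeType: mime, Data: data}})
		}
	}
	return parts
}

func main() {
	fetch := func(string) (string, string) { return "image/png", "aGVsbG8=" } // stub fetcher
	parts := toGeminiParts([]openaiPart{
		{Type: "text", Text: "describe this image"},
		{Type: "image_url", URL: "https://example.com/cat.png"},
	}, fetch)
	fmt.Printf("%+v\n", parts)
}
```

Note that the hunk discards the fetch error with `_`, so an unreachable image URL still appends a part with empty MIME type and data rather than failing the request.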
@@ -142,7 +173,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
if request.Stream {
var responseText string
-     errWithCode, responseText = p.sendStreamRequest(req)
+     errWithCode, responseText = p.sendStreamRequest(req, request.Model)
if errWithCode != nil {
return
}
@@ -155,6 +186,7 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
} else {
var geminiResponse = &GeminiChatResponse{
+     Model: request.Model,
Usage: &types.Usage{
PromptTokens: promptTokens,
},
@@ -170,18 +202,18 @@ func (p *GeminiProvider) ChatAction(request *types.ChatCompletionRequest, isMode
}
- func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
-     var choice types.ChatCompletionStreamChoice
-     choice.Delta.Content = geminiResponse.GetResponseText()
-     choice.FinishReason = &base.StopFinishReason
-     var response types.ChatCompletionStreamResponse
-     response.Object = "chat.completion.chunk"
-     response.Model = "gemini"
-     response.Choices = []types.ChatCompletionStreamChoice{choice}
-     return &response
- }
+ // func (p *GeminiProvider) streamResponseClaude2OpenAI(geminiResponse *GeminiChatResponse) *types.ChatCompletionStreamResponse {
+ //     var choice types.ChatCompletionStreamChoice
+ //     choice.Delta.Content = geminiResponse.GetResponseText()
+ //     choice.FinishReason = &base.StopFinishReason
+ //     var response types.ChatCompletionStreamResponse
+ //     response.Object = "chat.completion.chunk"
+ //     response.Model = "gemini"
+ //     response.Choices = []types.ChatCompletionStreamChoice{choice}
+ //     return &response
+ // }
- func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErrorWithStatusCode, string) {
+ func (p *GeminiProvider) sendStreamRequest(req *http.Request, model string) (*types.OpenAIErrorWithStatusCode, string) {
defer req.Body.Close()
// send the request
@@ -235,7 +267,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
Content string `json:"content"`
}
var dummy dummyStruct
-     err := json.Unmarshal([]byte(data), &dummy)
+     json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice types.ChatCompletionStreamChoice
choice.Delta.Content = dummy.Content
@@ -243,7 +275,7 @@ func (p *GeminiProvider) sendStreamRequest(req *http.Request) (*types.OpenAIErro
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "gemini-pro",
Model: model,
Choices: []types.ChatCompletionStreamChoice{choice},
}
jsonResponse, err := json.Marshal(response)
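With the model threaded through `sendStreamRequest`, every streamed chunk now reports whatever the client requested instead of a hardcoded `gemini-pro`. A self-contained sketch of the chunk assembly, using minimal stand-ins for the repo's stream types:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Minimal shapes matching OpenAI's chat.completion.chunk format.
type streamChoice struct {
	Delta struct {
		Content string `json:"content"`
	} `json:"delta"`
}

type streamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []streamChoice `json:"choices"`
}

// buildChunk mirrors what sendStreamRequest now does: the model name is a
// parameter supplied by the caller, not the constant "gemini-pro".
func buildChunk(id, model, content string, created int64) ([]byte, error) {
	var c streamChoice
	c.Delta.Content = content
	return json.Marshal(streamChunk{
		ID:      id,
		Object:  "chat.completion.chunk",
		Created: created,
		Model:   model,
		Choices: []streamChoice{c},
	})
}

func main() {
	b, err := buildChunk("chatcmpl-123", "gemini-pro-vision", "Hello", 1703473408)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
}
```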

View File

@@ -46,6 +46,7 @@ type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
Usage *types.Usage `json:"usage,omitempty"`
+     Model string `json:"model,omitempty"`
}
type GeminiChatCandidate struct {
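`Model` (like `Usage`) is not part of Gemini's wire format; the provider stamps it onto the decoded response so handlers and stream chunks can label output and count tokens correctly, and `omitempty` keeps the extra fields from disturbing round-trips of real Gemini JSON. A tiny sketch of that behavior:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed-down response type showing the effect of omitempty: the field
// never appears in serialized payloads unless the provider sets it.
type geminiChatResponse struct {
	Model string `json:"model,omitempty"`
}

func main() {
	var r geminiChatResponse
	// Gemini's upstream JSON carries no "model" key, so decoding leaves it empty.
	if err := json.Unmarshal([]byte(`{}`), &r); err != nil {
		panic(err)
	}
	r.Model = "gemini-pro" // the provider stamps it after decoding
	b, _ := json.Marshal(r)
	fmt.Println(string(b)) // {"model":"gemini-pro"}
}
```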