Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2023-12-27 02:14:34 +00:00
14 changed files with 218 additions and 31 deletions

View File

@@ -432,6 +432,15 @@ func init() {
Root: "gemini-pro",
Parent: nil,
},
{
Id: "gemini-pro-vision",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "gemini-pro-vision",
Parent: nil,
},
{
Id: "chatglm_turbo",
Object: "model",

View File

@@ -7,11 +7,18 @@ import (
"io"
"net/http"
"one-api/common"
"one-api/common/image"
"strings"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
GeminiVisionMaxImageNum = 16
)
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
@@ -97,6 +104,30 @@ func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
},
},
}
openaiContent := message.ParseContent()
var parts []GeminiPart
imageNum := 0
for _, part := range openaiContent {
if part.Type == ContentTypeText {
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == ContentTypeImageURL {
imageNum += 1
if imageNum > GeminiVisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"

View File

@@ -187,6 +187,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
fullTextResponse.Model = model
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
usage := Usage{
PromptTokens: promptTokens,

View File

@@ -180,9 +180,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeGemini:
requestBaseURL := "https://generativelanguage.googleapis.com"
if baseURL != "" {
@@ -200,21 +197,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
// case APITypeZhipu:
// method := "invoke"
// if textRequest.Stream {
// method = "sse-invoke"
// }
// fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
// case APITypeAli:
// fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
// if relayMode == RelayModeEmbeddings {
// fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
// }
// case APITypeTencent:
// fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
// case APITypeAIProxyLibrary:
// fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
case APITypeAli:
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
if relayMode == RelayModeEmbeddings {
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
}
case APITypeTencent:
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
case APITypeAIProxyLibrary:
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
}
var promptTokens int
var completionTokens int
@@ -410,9 +407,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// case APITypeTencent:
// req.Header.Set("Authorization", apiKey)
case APITypePaLM:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
case APITypeGemini:
// do not set Authorization header
req.Header.Set("x-goog-api-key", apiKey)
default:
req.Header.Set("Authorization", "Bearer "+apiKey)
}

View File

@@ -69,6 +69,22 @@ type ImageContent struct {
ImageURL *ImageURL `json:"image_url,omitempty"`
}
const (
ContentTypeText = "text"
ContentTypeImageURL = "image_url"
)
type OpenAIMessageContent struct {
Type string `json:"type,omitempty"`
Text string `json:"text"`
ImageURL *ImageURL `json:"image_url,omitempty"`
}
func (m Message) IsStringContent() bool {
_, ok := m.Content.(string)
return ok
}
func (m Message) StringContent() string {
content, ok := m.Content.(string)
if ok {
@@ -82,7 +98,7 @@ func (m Message) StringContent() string {
if !ok {
continue
}
if contentMap["type"] == "text" {
if contentMap["type"] == ContentTypeText {
if subStr, ok := contentMap["text"].(string); ok {
contentStr += subStr
}
@@ -93,6 +109,47 @@ func (m Message) StringContent() string {
return ""
}
func (m Message) ParseContent() []OpenAIMessageContent {
var contentList []OpenAIMessageContent
content, ok := m.Content.(string)
if ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: content,
})
return contentList
}
anyList, ok := m.Content.([]any)
if ok {
for _, contentItem := range anyList {
contentMap, ok := contentItem.(map[string]any)
if !ok {
continue
}
switch contentMap["type"] {
case ContentTypeText:
if subStr, ok := contentMap["text"].(string); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeText,
Text: subStr,
})
}
case ContentTypeImageURL:
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
contentList = append(contentList, OpenAIMessageContent{
Type: ContentTypeImageURL,
ImageURL: &ImageURL{
Url: subObj["url"].(string),
},
})
}
}
}
return contentList
}
return nil
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
@@ -281,7 +338,7 @@ type OpenAITextResponseChoice struct {
type OpenAITextResponse struct {
Id string `json:"id"`
Model string `json:"model"`
Model string `json:"model,omitempty"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`