From e057c0e42e59e2db746370fc4a4e348ed3dbd459 Mon Sep 17 00:00:00 2001
From: CaIon <1808837298@qq.com>
Date: Mon, 18 Dec 2023 23:45:08 +0800
Subject: [PATCH] Add Gemini support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 common/constants.go                    |   1 +
 common/model-ratio.go                  |   1 +
 controller/channel-test.go             |   2 +
 controller/model.go                    |   9 +
 controller/relay-gemini.go             | 307 +++++++++++++++++++++++++
 controller/relay-text.go               |  46 ++++
 middleware/distributor.go              |   2 +
 web/src/constants/channel.constants.js |   1 +
 web/src/pages/Channel/EditChannel.js   |   3 +
 9 files changed, 372 insertions(+)
 create mode 100644 controller/relay-gemini.go

diff --git a/common/constants.go b/common/constants.go
index 939b62f..50b7091 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -190,6 +190,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
+	ChannelTypeGemini         = 24
 )
 
 var ChannelBaseURLs = []string{
diff --git a/common/model-ratio.go b/common/model-ratio.go
index 0aa6aa6..46505fe 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -61,6 +61,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":   8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":  0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":        1,
+	"gemini-pro":    1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/channel-test.go b/controller/channel-test.go
index c441a48..2f94fdf 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -29,6 +29,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 		fallthrough
 	case common.ChannelType360:
 		fallthrough
+	case common.ChannelTypeGemini:
+		fallthrough
 	case common.ChannelTypeXunfei:
 		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 	case common.ChannelTypeAzure:
diff --git a/controller/model.go b/controller/model.go
index 563dd40..9fa2132 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -423,6 +423,15 @@ func init() {
 			Root:       "PaLM-2",
 			Parent:     nil,
 		},
+		{
+			Id:         "gemini-pro",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "google",
+			Permission: permission,
+			Root:       "gemini-pro",
+			Parent:     nil,
+		},
 		{
 			Id:         "chatglm_turbo",
 			Object:     "model",
diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go
new file mode 100644
index 0000000..f68d8c1
--- /dev/null
+++ b/controller/relay-gemini.go
@@ -0,0 +1,307 @@
+package controller
+
+import (
+	"bufio"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+	"strings"
+
+	"github.com/gin-gonic/gin"
+)
+
+type GeminiChatRequest struct {
+	Contents         []GeminiChatContent        `json:"contents"`
+	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+	Tools            []GeminiChatTools          `json:"tools,omitempty"`
+}
+
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
+}
+
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
+}
+
+type GeminiChatSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+type GeminiChatTools struct {
+	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
`json:"functionDeclarations,omitempty"` +} + +type GeminiChatGenerationConfig struct { + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"topP,omitempty"` + TopK float64 `json:"topK,omitempty"` + MaxOutputTokens int `json:"maxOutputTokens,omitempty"` + CandidateCount int `json:"candidateCount,omitempty"` + StopSequences []string `json:"stopSequences,omitempty"` +} + +// Setting safety to the lowest possible values since Gemini is already powerless enough +func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest { + geminiRequest := GeminiChatRequest{ + Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)), + //SafetySettings: []GeminiChatSafetySettings{ + // { + // Category: "HARM_CATEGORY_HARASSMENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_HATE_SPEECH", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + // { + // Category: "HARM_CATEGORY_DANGEROUS_CONTENT", + // Threshold: "BLOCK_ONLY_HIGH", + // }, + //}, + GenerationConfig: GeminiChatGenerationConfig{ + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + MaxOutputTokens: textRequest.MaxTokens, + }, + } + if textRequest.Functions != nil { + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: textRequest.Functions, + }, + } + } + shouldAddDummyModelMessage := false + for _, message := range textRequest.Messages { + content := GeminiChatContent{ + Role: message.Role, + Parts: []GeminiPart{ + { + Text: string(message.Content), + }, + }, + } + // there's no assistant role in gemini and API shall vomit if Role is not user or model + if content.Role == "assistant" { + content.Role = "model" + } + // Converting system prompt to prompt from user for the same reason + if content.Role == "system" { + content.Role = "user" + shouldAddDummyModelMessage = true + } + geminiRequest.Contents = append(geminiRequest.Contents, content) + + // If a system message is the last message, we need to add a dummy model message to make gemini happy + if shouldAddDummyModelMessage { + geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{ + Role: "model", + Parts: []GeminiPart{ + { + Text: "Okay", + }, + }, + }) + shouldAddDummyModelMessage = false + } + } + + return &geminiRequest +} + +type GeminiChatResponse struct { + Candidates []GeminiChatCandidate `json:"candidates"` + PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` +} + +func (g *GeminiChatResponse) GetResponseText() string { + if g == nil { + return "" + } + if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 { + return g.Candidates[0].Content.Parts[0].Text + } + return "" +} + +type GeminiChatCandidate struct { + Content GeminiChatContent `json:"content"` + FinishReason string `json:"finishReason"` + Index int64 `json:"index"` + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +type GeminiChatSafetyRating struct { + Category string `json:"category"` + Probability string `json:"probability"` +} + +type GeminiChatPromptFeedback struct { + SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"` +} + +func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse { + fullTextResponse := OpenAITextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion", + Created: common.GetTimestamp(), + Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), + } + 
content, _ := json.Marshal("") + for i, candidate := range response.Candidates { + choice := OpenAITextResponseChoice{ + Index: i, + Message: Message{ + Role: "assistant", + Content: content, + }, + FinishReason: stopFinishReason, + } + content, _ = json.Marshal(candidate.Content.Parts[0].Text) + if len(candidate.Content.Parts) > 0 { + choice.Message.Content = content + } + fullTextResponse.Choices = append(fullTextResponse.Choices, choice) + } + return &fullTextResponse +} + +func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = geminiResponse.GetResponseText() + choice.FinishReason = &stopFinishReason + var response ChatCompletionsStreamResponse + response.Object = "chat.completion.chunk" + response.Model = "gemini" + response.Choices = []ChatCompletionsStreamResponseChoice{choice} + return &response +} + +func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { + responseText := "" + dataChan := make(chan string) + stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + go func() { + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data + } + stopChan <- true + }() + setEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Object: "chat.completion.chunk", + Created: common.GetTimestamp(), + Model: "gemini-pro", + Choices: []ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + } + return nil, responseText +} + +func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var geminiResponse GeminiChatResponse + err = json.Unmarshal(responseBody, &geminiResponse) + if err != nil 
+		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	if len(geminiResponse.Candidates) == 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
+	completionTokens := countTokenText(geminiResponse.GetResponseText(), model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	_, err = c.Writer.Write(jsonResponse)
+	return nil, &usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 93a0ab6..6231786 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -26,6 +26,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
+	APITypeGemini
 )
 
 var httpClient *http.Client
@@ -119,6 +120,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -180,6 +183,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
+	case APITypeGemini:
+		requestBaseURL := "https://generativelanguage.googleapis.com"
+		if baseURL != "" {
+			requestBaseURL = baseURL
+		}
+		version := "v1"
+		if c.GetString("api_version") != "" {
+			version = c.GetString("api_version")
+		}
+		action := "generateContent"
+		if textRequest.Stream {
+			action = "streamGenerateContent"
+		}
+		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		fullRequestURL += "?key=" + apiKey
 	case APITypeZhipu:
 		method := "invoke"
 		if textRequest.Stream {
@@ -280,6 +300,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeGemini:
+		geminiChatRequest := requestOpenAI2Gemini(textRequest)
+		jsonStr, err := json.Marshal(geminiChatRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeZhipu:
 		zhipuRequest := requestOpenAI2Zhipu(textRequest)
 		jsonStr, err := json.Marshal(zhipuRequest)
@@ -539,6 +566,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeGemini:
+		if textRequest.Stream {
+			err, responseText := geminiChatStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage.PromptTokens = promptTokens
+			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			return nil
+		} else {
+			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	case APITypeZhipu:
 		if isStream {
 			err, usage := zhipuStreamHandler(c, resp)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 88a402c..8446c7e 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -107,6 +107,8 @@ func Distribute() func(c *gin.Context) {
 			c.Set("api_version", channel.Other)
 		case common.ChannelTypeAIProxyLibrary:
 			c.Set("library_id", channel.Other)
+		case common.ChannelTypeGemini:
+			c.Set("api_version", channel.Other)
 		}
 		c.Next()
 	}
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 82d21c6..8a021af 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [
   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black', label: 'Anthropic Claude' },
   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive', label: 'Azure OpenAI' },
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange', label: 'Google PaLM2' },
+  { key: 24, text: 'Google Gemini', value: 24, color: 'orange', label: 'Google Gemini' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue', label: '百度文心千帆' },
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange', label: '阿里通义千问' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue', label: '讯飞星火认知' },
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 5cae807..f745ccf 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -86,6 +86,9 @@ const EditChannel = (props) => {
       case 23:
         localModels = ['hunyuan'];
         break;
+      case 24:
+        localModels = ['gemini-pro'];
+        break;
     }
     setInputs((inputs) => ({...inputs, models: localModels}));
   }
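
Review note (separate from the patch): the sketch below shows, in isolation, the message conversion that requestOpenAI2Gemini performs — mapping "assistant" to "model", rewriting a "system" message as a "user" message, and inserting a dummy "Okay" model reply so user/model turns keep alternating. The type definitions are trimmed copies of the ones in relay-gemini.go; the sample messages are illustrative only.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed copies of the shapes in relay-gemini.go, just enough for the demo.
type GeminiPart struct {
	Text string `json:"text,omitempty"`
}

type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

func main() {
	// An OpenAI-style conversation with a system prompt.
	messages := []struct{ Role, Content string }{
		{"system", "You are a terse assistant."},
		{"user", "Hello"},
	}
	var contents []GeminiChatContent
	for _, m := range messages {
		role := m.Role
		// Gemini accepts only "user" and "model" roles.
		if role == "assistant" {
			role = "model"
		}
		addDummy := false
		if role == "system" {
			role = "user"
			addDummy = true
		}
		contents = append(contents, GeminiChatContent{Role: role, Parts: []GeminiPart{{Text: m.Content}}})
		if addDummy {
			// Dummy model reply, mirroring shouldAddDummyModelMessage in the patch.
			contents = append(contents, GeminiChatContent{Role: "model", Parts: []GeminiPart{{Text: "Okay"}}})
		}
	}
	out, _ := json.MarshalIndent(map[string]any{"contents": contents}, "", "  ")
	fmt.Println(string(out))
}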
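
Review note (separate from the patch): the upstream URL construction in the APITypeGemini branch of relay-text.go can be exercised standalone. The helper below mirrors that branch under the patch's defaults (base URL https://generativelanguage.googleapis.com, version "v1"); the key value is a placeholder, not a real credential.

package main

import "fmt"

// buildGeminiURL mirrors the APITypeGemini branch in relay-text.go:
// streaming requests hit :streamGenerateContent, everything else :generateContent,
// and the API key is passed as a query parameter.
func buildGeminiURL(baseURL, version, model string, stream bool, apiKey string) string {
	if baseURL == "" {
		baseURL = "https://generativelanguage.googleapis.com"
	}
	if version == "" {
		version = "v1"
	}
	action := "generateContent"
	if stream {
		action = "streamGenerateContent"
	}
	return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", baseURL, version, model, action, apiKey)
}

func main() {
	// "AIza-example" is a placeholder key for illustration.
	fmt.Println(buildGeminiURL("", "", "gemini-pro", false, "AIza-example"))
	fmt.Println(buildGeminiURL("", "", "gemini-pro", true, "AIza-example"))
}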
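
Review note (separate from the patch): geminiChatStreamHandler scrapes `"text": "..."` lines out of the streamed response body and then round-trips each still-escaped fragment through json.Unmarshal so the JSON parser decodes backslash escapes. A standalone sketch of that unescaping step, with a made-up input line:

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// unescapeFragment mirrors the trick in geminiChatStreamHandler: wrap the raw,
// still-escaped fragment in a JSON object and unmarshal it, so sequences like
// \n and \" are decoded by the JSON parser instead of by hand.
func unescapeFragment(raw string) (string, error) {
	wrapped := fmt.Sprintf("{\"content\": \"%s\"}", raw)
	var dummy struct {
		Content string `json:"content"`
	}
	if err := json.Unmarshal([]byte(wrapped), &dummy); err != nil {
		return "", err
	}
	return dummy.Content, nil
}

func main() {
	// One line as it might appear inside a streamed chunk (illustrative).
	line := `          "text": "Hello\nthere"`
	line = strings.TrimSpace(line)
	line = strings.TrimPrefix(line, "\"text\": \"")
	line = strings.TrimSuffix(line, "\"")
	text, err := unescapeFragment(line)
	if err != nil {
		panic(err)
	}
	fmt.Println(text) // prints "Hello" and "there" on separate lines
}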