one-api/relay/channel/gemini/main.go
Laisky.Cai ddd2dd1041 fix: Refactor relay/channel, upgrade deps, improve request handling and error messages.
* Updated relay/channel/gemini package to use gin-gonic/gin for routing
* Added timeouts, environment variable for proxy, and error handling for HTTP clients in relay/util/init.go
* Improved error handling, URL path cases, and channel selection logic in middleware/distributor.go
* Added Content-Type header, closed request bodies, and added context to requests in relay/channel/common.go
* Upgraded various dependencies in go.mod
* Modified the GetRequestURL method in relay/channel/gemini/adaptor.go to include a "key" parameter in the URL and set a default version of "v1beta" (see the sketch after this list)
* Added io and net/http packages in relay/channel/gemini/adaptor.go and relay/channel/gemini/main.go
* Added a GeminiStreamResp struct and modified response handling in relay/channel/gemini/main.go
* Added the gin-gonic/gin import and improved error handling in relay/channel/gemini/main.go
2024-03-19 03:11:19 +00:00

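A rough sketch of the GetRequestURL change described above, assuming a meta object like one-api's util.RelayMeta with BaseURL, APIKey, ActualModelName, and IsStream fields (these names are illustrative, not verified against adaptor.go):

func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    version := "v1beta" // assumed default per the commit message
    // streaming and non-streaming requests use different actions
    action := "generateContent"
    if meta.IsStream {
        action = "streamGenerateContent"
    }
    // the API key travels as a "key" query parameter rather than a header
    return fmt.Sprintf("%s/%s/models/%s:%s?key=%s",
        meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
}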

package gemini

import (
    "context"
    "encoding/json"
    "fmt"
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/image"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/relay/channel/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
)

// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
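// VisionMaxImageNum caps how many inline images are forwarded to Gemini per request.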
const (
    VisionMaxImageNum = 16
)

// Setting safety to the lowest possible values since Gemini is already powerless enough
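// ConvertRequest converts an OpenAI-style chat request into a Gemini ChatRequest,
// mapping roles, text and image parts, function declarations, and generation settings.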
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
    geminiRequest := ChatRequest{
        Contents: make([]ChatContent, 0, len(textRequest.Messages)),
        SafetySettings: []ChatSafetySettings{
            {
                Category:  "HARM_CATEGORY_HARASSMENT",
                Threshold: config.GeminiSafetySetting,
            },
            {
                Category:  "HARM_CATEGORY_HATE_SPEECH",
                Threshold: config.GeminiSafetySetting,
            },
            {
                Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                Threshold: config.GeminiSafetySetting,
            },
            {
                Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
                Threshold: config.GeminiSafetySetting,
            },
        },
        GenerationConfig: ChatGenerationConfig{
            Temperature:     textRequest.Temperature,
            TopP:            textRequest.TopP,
            MaxOutputTokens: textRequest.MaxTokens,
        },
    }
    if textRequest.Functions != nil {
        geminiRequest.Tools = []ChatTools{
            {
                FunctionDeclarations: textRequest.Functions,
            },
        }
    }
    shouldAddDummyModelMessage := false
    for _, message := range textRequest.Messages {
        content := ChatContent{
            Role: message.Role,
            Parts: []Part{
                {
                    Text: message.StringContent(),
                },
            },
        }
        openaiContent := message.ParseContent()
        var parts []Part
        imageNum := 0
        for _, part := range openaiContent {
            if part.Type == model.ContentTypeText {
                parts = append(parts, Part{
                    Text: part.Text,
                })
            } else if part.Type == model.ContentTypeImageURL {
                imageNum++
                // skip images beyond the per-request limit
                if imageNum > VisionMaxImageNum {
                    continue
                }
                mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
                if err != nil {
                    logger.Warn(context.TODO(),
                        fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
                    continue
                }
                parts = append(parts, Part{
                    InlineData: &InlineData{
                        MimeType: mimeType,
                        Data:     data,
                    },
                })
            }
        }
        logger.Info(context.TODO(),
            fmt.Sprintf("send message with %d parts (%d images) to gemini", len(parts), imageNum))
        content.Parts = parts
        // Gemini has no assistant role, and the API rejects any role other than user or model
        if content.Role == "assistant" {
            content.Role = "model"
        }
        // Convert the system prompt to a user message for the same reason
        if content.Role == "system" {
            content.Role = "user"
            shouldAddDummyModelMessage = true
        }
        geminiRequest.Contents = append(geminiRequest.Contents, content)
        // If a system message is the last message, add a dummy model message to keep gemini happy
        if shouldAddDummyModelMessage {
            geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
                Role: "model",
                Parts: []Part{
                    {
                        Text: "Okay",
                    },
                },
            })
            shouldAddDummyModelMessage = false
        }
    }
    return &geminiRequest
}

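// ChatResponse is the non-streaming Gemini chat completion response.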
type ChatResponse struct {
    Candidates     []ChatCandidate    `json:"candidates"`
    PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}

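// GetResponseText returns the text of the first part of the first candidate,
// or an empty string if there is none.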
func (g *ChatResponse) GetResponseText() string {
    if g == nil {
        return ""
    }
    if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
        return g.Candidates[0].Content.Parts[0].Text
    }
    return ""
}

type ChatCandidate struct {
    Content       ChatContent        `json:"content"`
    FinishReason  string             `json:"finishReason"`
    Index         int64              `json:"index"`
    SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}

type ChatSafetyRating struct {
    Category    string `json:"category"`
    Probability string `json:"probability"`
}

type ChatPromptFeedback struct {
    SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}

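// responseGeminiChat2OpenAI converts a Gemini response into an OpenAI-style TextResponse.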
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
    }
    for i, candidate := range response.Candidates {
        choice := openai.TextResponseChoice{
            Index: i,
            Message: model.Message{
                Role:    "assistant",
                Content: "",
            },
            FinishReason: constant.StopFinishReason,
        }
        if len(candidate.Content.Parts) > 0 {
            choice.Message.Content = candidate.Content.Parts[0].Text
        }
        fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
    }
    return &fullTextResponse
}

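// streamResponseGeminiChat2OpenAI wraps the full Gemini response text in a single
// OpenAI-style stream chunk.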
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = geminiResponse.GetResponseText()
    choice.FinishReason = &constant.StopFinishReason
    var response openai.ChatCompletionsStreamResponse
    response.Object = "chat.completion.chunk"
    response.Model = "gemini"
    response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
    return &response
}

// [{
//   "candidates": [
//     {
//       "content": {
//         "parts": [
//           {
//             "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
//           }
//         ],
//         "role": "model"
//       },
//       "finishReason": "STOP",
//       "index": 0,
//       "safetyRatings": [
//         {
//           "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
//           "probability": "NEGLIGIBLE"
//         },
//         {
//           "category": "HARM_CATEGORY_HATE_SPEECH",
//           "probability": "NEGLIGIBLE"
//         },
//         {
//           "category": "HARM_CATEGORY_HARASSMENT",
//           "probability": "NEGLIGIBLE"
//         },
//         {
//           "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
//           "probability": "NEGLIGIBLE"
//         }
//       ]
//     }
//   ],
//   "promptFeedback": {
//     "safetyRatings": [
//       {
//         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
//         "probability": "NEGLIGIBLE"
//       },
//       {
//         "category": "HARM_CATEGORY_HATE_SPEECH",
//         "probability": "NEGLIGIBLE"
//       },
//       {
//         "category": "HARM_CATEGORY_HARASSMENT",
//         "probability": "NEGLIGIBLE"
//       },
//       {
//         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
//         "probability": "NEGLIGIBLE"
//       }
//     ]
//   }
// }]
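//
// GeminiStreamResp mirrors the array-of-chunks payload shown above; only the
// candidate text is consumed downstream.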
type GeminiStreamResp struct {
    Candidates []struct {
        Content struct {
            Parts []struct {
                Text string `json:"text"`
            } `json:"parts"`
            Role string `json:"role"`
        } `json:"content"`
        FinishReason string `json:"finishReason"`
        Index        int64  `json:"index"`
    } `json:"candidates"`
}

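// StreamHandler buffers the entire upstream body, concatenates the candidate
// text, and replays it to the client as one SSE chunk followed by [DONE];
// it does not stream incrementally from Gemini. The aggregated text is
// returned so the caller can count completion tokens.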
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
    // close the upstream body on every return path
    defer func() { _ = resp.Body.Close() }()

    responseText := ""
    respBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
    }
    var respData []GeminiStreamResp
    if err = json.Unmarshal(respBody, &respData); err != nil {
        return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
    }
    for _, chunk := range respData {
        for _, cad := range chunk.Candidates {
            for _, part := range cad.Content.Parts {
                responseText += part.Text
            }
        }
    }
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = responseText
    resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
        Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   "gemini-pro",
        Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
    })
    if err != nil {
        return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
    }
    common.SetEventStreamHeaders(c)
    c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
    c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})

    return nil, responseText
}

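// Handler converts a non-streaming Gemini response to the OpenAI format,
// writes it to the client, and returns the computed token usage.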
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
    responseBody, err := io.ReadAll(resp.Body)
    if err != nil {
        return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
    }
    err = resp.Body.Close()
    if err != nil {
        return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
    }
    var geminiResponse ChatResponse
    err = json.Unmarshal(responseBody, &geminiResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
    }
    if len(geminiResponse.Candidates) == 0 {
        return &model.ErrorWithStatusCode{
            Error: model.Error{
                Message: "No candidates returned",
                Type:    "server_error",
                Param:   "",
                Code:    500,
            },
            StatusCode: resp.StatusCode,
        }, nil
    }
    fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
    fullTextResponse.Model = modelName
    completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
    usage := model.Usage{
        PromptTokens:     promptTokens,
        CompletionTokens: completionTokens,
        TotalTokens:      promptTokens + completionTokens,
    }
    fullTextResponse.Usage = usage
    jsonResponse, err := json.Marshal(fullTextResponse)
    if err != nil {
        return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
    }
    c.Writer.Header().Set("Content-Type", "application/json")
    c.Writer.WriteHeader(resp.StatusCode)
    // report (rather than silently drop) a failed write to the client
    if _, err = c.Writer.Write(jsonResponse); err != nil {
        logger.SysError("write response body failed: " + err.Error())
    }
    return nil, &usage
}