fix: Refactor relay/channel, upgrade deps, improve request handling and error messages.

* Updated the relay/channel/gemini package to work with gin-gonic/gin request contexts
* Added timeouts, a proxy environment variable, and error handling for HTTP clients in relay/util/init.go (see the sketch after this list)
* Improved error handling, URL path handling, and channel selection logic in middleware/distributor.go
* Added a Content-Type header, closed request bodies, and propagated the request context in relay/channel/common.go
* Upgraded various dependencies in go.mod
* Modified the GetRequestURL method in relay/channel/gemini/adaptor.go to append a "key" query parameter to the URL and default the API version to "v1beta"
* Added a GeminiStreamResp struct and reworked stream response handling in relay/channel/gemini/main.go
* Regrouped the io, net/http, and gin-gonic/gin imports and improved error handling in relay/channel/gemini/adaptor.go and relay/channel/gemini/main.go
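The relay/util/init.go diff is not included on this page, so the following is only a minimal sketch of the kind of HTTP client initialization the second bullet describes. The RELAY_PROXY variable name and the 30-second timeout are illustrative assumptions, not names taken from the commit:

```go
// Hypothetical sketch of an HTTP client with a timeout and an
// environment-variable-driven proxy; names here are assumptions.
package util

import (
	"net/http"
	"net/url"
	"os"
	"time"
)

var HTTPClient *http.Client

func init() {
	transport := &http.Transport{}
	// Route outbound traffic through a proxy when the env var is set.
	if proxyURL := os.Getenv("RELAY_PROXY"); proxyURL != "" {
		if u, err := url.Parse(proxyURL); err == nil {
			transport.Proxy = http.ProxyURL(u)
		}
	}
	HTTPClient = &http.Client{
		Transport: transport,
		Timeout:   30 * time.Second, // fail fast instead of hanging forever
	}
}
```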
Author: Laisky.Cai
Date: 2024-03-19 03:11:19 +00:00
Parent: d379377eca
Commit: ddd2dd1041
8 changed files with 371 additions and 184 deletions

relay/channel/common.go

@@ -22,10 +22,15 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
 	if err != nil {
 		return nil, errors.Wrap(err, "get request url failed")
 	}
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
+	req, err := http.NewRequestWithContext(c.Request.Context(),
+		c.Request.Method, fullRequestURL, requestBody)
 	if err != nil {
 		return nil, errors.Wrap(err, "new request failed")
 	}
+	req.Header.Add("Content-Type", "application/json")
 	err = a.SetupRequestHeader(c, req, meta)
 	if err != nil {
 		return nil, errors.Wrap(err, "setup request header failed")
@@ -47,5 +52,6 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
 	}
+	_ = req.Body.Close()
 	_ = c.Request.Body.Close()
 	return resp, nil
 }
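For context on the http.NewRequestWithContext change above: binding the upstream request to c.Request.Context() means the relayed call is aborted as soon as the client disconnects or the context deadline passes, instead of holding the upstream connection open. A self-contained sketch (not part of the diff):

```go
// Minimal illustration of why NewRequestWithContext matters: cancel the
// caller's context and the in-flight upstream call aborts with it.
package main

import (
	"context"
	"fmt"
	"net/http"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	req, _ := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil)
	_, err := http.DefaultClient.Do(req)
	// The request fails as soon as ctx expires, rather than blocking until
	// a transport-level timeout (if any) fires.
	fmt.Println(err) // ... context deadline exceeded
}
```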

relay/channel/gemini/adaptor.go

@@ -2,6 +2,9 @@ package gemini

 import (
 	"fmt"
+	"io"
+	"net/http"
+
 	"github.com/Laisky/errors/v2"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -9,8 +12,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
-	"io"
-	"net/http"
 )

 type Adaptor struct {
@@ -21,17 +22,18 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 }

 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-	version := helper.AssignOrDefault(meta.APIVersion, "v1")
+	version := helper.AssignOrDefault(meta.APIVersion, "v1beta")
 	action := "generateContent"
 	if meta.IsStream {
 		action = "streamGenerateContent"
 	}
-	return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
+	return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
 	channelhelper.SetupCommonRequestHeader(c, req, meta)
-	req.Header.Set("x-goog-api-key", meta.APIKey)
+	req.URL.Query().Add("key", meta.APIKey)
 	return nil
 }
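One caveat worth noting on the hunk above: in Go, req.URL.Query() parses RawQuery and returns a copy of the values, so calling Add on that copy never modifies the outgoing request. The new line is effectively a no-op, and is harmless here only because GetRequestURL already embeds ?key=... in the URL. If the parameter actually had to be added at header-setup time, the mutation would need to be written back, roughly like this (a sketch, not code from the commit):

```go
package gemini

import "net/http"

// addQueryParam appends a query parameter to req in place. The re-encode
// into RawQuery is the step that actually changes the request URL;
// mutating the url.Values copy returned by Query() alone does nothing.
func addQueryParam(req *http.Request, key, value string) {
	q := req.URL.Query()
	q.Add(key, value)
	req.URL.RawQuery = q.Encode()
}
```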

relay/channel/gemini/main.go

@@ -1,10 +1,13 @@
 package gemini

 import (
 	"bufio"
 	"context"
 	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/helper"
@@ -13,11 +16,6 @@ import (
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/model"
-	"io"
-	"net/http"
 	"strings"
-
-	"github.com/gin-gonic/gin"
 )

 // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -99,8 +97,9 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
 			})
 		}
 	}
 	logger.Info(context.TODO(),
-		fmt.Sprintf("send %d images to gemini-pro-vision", len(parts)))
+		fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum))
 	content.Parts = parts

 	// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -197,73 +196,182 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
 	return &response
 }

+// [{
+//   "candidates": [
+//     {
+//       "content": {
+//         "parts": [
+//           {
+//             "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
+//           }
+//         ],
+//         "role": "model"
+//       },
+//       "finishReason": "STOP",
+//       "index": 0,
+//       "safetyRatings": [
+//         {
+//           "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+//           "probability": "NEGLIGIBLE"
+//         },
+//         {
+//           "category": "HARM_CATEGORY_HATE_SPEECH",
+//           "probability": "NEGLIGIBLE"
+//         },
+//         {
+//           "category": "HARM_CATEGORY_HARASSMENT",
+//           "probability": "NEGLIGIBLE"
+//         },
+//         {
+//           "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+//           "probability": "NEGLIGIBLE"
+//         }
+//       ]
+//     }
+//   ],
+//   "promptFeedback": {
+//     "safetyRatings": [
+//       {
+//         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+//         "probability": "NEGLIGIBLE"
+//       },
+//       {
+//         "category": "HARM_CATEGORY_HATE_SPEECH",
+//         "probability": "NEGLIGIBLE"
+//       },
+//       {
+//         "category": "HARM_CATEGORY_HARASSMENT",
+//         "probability": "NEGLIGIBLE"
+//       },
+//       {
+//         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+//         "probability": "NEGLIGIBLE"
+//       }
+//     ]
+//   }
+// }]
+type GeminiStreamResp struct {
+	Candidates []struct {
+		Content struct {
+			Parts []struct {
+				Text string `json:"text"`
+			} `json:"parts"`
+			Role string `json:"role"`
+		} `json:"content"`
+		FinishReason string `json:"finishReason"`
+		Index        int64  `json:"index"`
+	} `json:"candidates"`
+}
+
 func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
 	responseText := ""
-	dataChan := make(chan string)
-	stopChan := make(chan bool)
-	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			data = strings.TrimSpace(data)
-			if !strings.HasPrefix(data, "\"text\": \"") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "\"text\": \"")
-			data = strings.TrimSuffix(data, "\"")
-			dataChan <- data
-		}
-		stopChan <- true
-	}()
-	common.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			// this is used to prevent annoying \ related format bug
-			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-			type dummyStruct struct {
-				Content string `json:"content"`
-			}
-			var dummy dummyStruct
-			err := json.Unmarshal([]byte(data), &dummy)
-			responseText += dummy.Content
-			var choice openai.ChatCompletionsStreamResponseChoice
-			choice.Delta.Content = dummy.Content
-			response := openai.ChatCompletionsStreamResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
-				Object:  "chat.completion.chunk",
-				Created: helper.GetTimestamp(),
-				Model:   "gemini-pro",
-				Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				logger.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-			return false
-		}
-	})
-	err := resp.Body.Close()
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
+	}
+
+	var respData []GeminiStreamResp
+	if err = json.Unmarshal(respBody, &respData); err != nil {
+		return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
+	}
+
+	for _, chunk := range respData {
+		for _, cad := range chunk.Candidates {
+			for _, part := range cad.Content.Parts {
+				responseText += part.Text
+			}
+		}
+	}
+
+	var choice openai.ChatCompletionsStreamResponseChoice
+	choice.Delta.Content = responseText
+	resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+		Object:  "chat.completion.chunk",
+		Created: helper.GetTimestamp(),
+		Model:   "gemini-pro",
+		Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+	})
+	if err != nil {
+		return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
+	}
+	c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
+	c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+
+	// dataChan := make(chan string)
+	// stopChan := make(chan bool)
+	// scanner := bufio.NewScanner(resp.Body)
+	// scanner.Split(bufio.ScanLines)
+	// // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	// // 	if atEOF && len(data) == 0 {
+	// // 		return 0, nil, nil
+	// // 	}
+	// // 	if i := strings.Index(string(data), "\n"); i >= 0 {
+	// // 		return i + 1, data[0:i], nil
+	// // 	}
+	// // 	if atEOF {
+	// // 		return len(data), data, nil
+	// // 	}
+	// // 	return 0, nil, nil
+	// // })
+	// go func() {
+	// 	var content string
+	// 	for scanner.Scan() {
+	// 		line := strings.TrimSpace(scanner.Text())
+	// 		fmt.Printf("> gemini got line: %s\n", line)
+	// 		content += line
+	// 		// if !strings.HasPrefix(data, "\"text\": \"") {
+	// 		// 	continue
+	// 		// }
+	// 		// data = strings.TrimPrefix(data, "\"text\": \"")
+	// 		// data = strings.TrimSuffix(data, "\"")
+	// 		// dataChan <- data
+	// 	}
+	// 	dataChan <- content
+	// 	stopChan <- true
+	// }()
+	// common.SetEventStreamHeaders(c)
+	// c.Stream(func(w io.Writer) bool {
+	// 	select {
+	// 	case data := <-dataChan:
+	// 		// this is used to prevent annoying \ related format bug
+	// 		data = fmt.Sprintf("{\"content\": \"%s\"}", data)
+	// 		type dummyStruct struct {
+	// 			Content string `json:"content"`
+	// 		}
+	// 		var dummy dummyStruct
+	// 		err := json.Unmarshal([]byte(data), &dummy)
+	// 		responseText += dummy.Content
+	// 		var choice openai.ChatCompletionsStreamResponseChoice
+	// 		choice.Delta.Content = dummy.Content
+	// 		response := openai.ChatCompletionsStreamResponse{
+	// 			Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+	// 			Object:  "chat.completion.chunk",
+	// 			Created: helper.GetTimestamp(),
+	// 			Model:   "gemini-pro",
+	// 			Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+	// 		}
+	// 		jsonResponse, err := json.Marshal(response)
+	// 		if err != nil {
+	// 			logger.SysError("error marshalling stream response: " + err.Error())
+	// 			return true
+	// 		}
+	// 		c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+	// 		return true
+	// 	case <-stopChan:
+	// 		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+	// 		return false
+	// 	}
+	// })
+	if err := resp.Body.Close(); err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
 	}
 	return nil, responseText
 }
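The net effect of the new StreamHandler: instead of scanning the upstream body line by line and pattern-matching on "text": prefixes, it reads the entire streamGenerateContent response (a single JSON array of chunks), concatenates the text of every candidate part, and emits one OpenAI-style chunk followed by [DONE]. That trades true incremental streaming for much simpler and more robust parsing. A standalone sketch of the aggregation step, using the struct from the diff and a made-up two-chunk payload:

```go
// Standalone sketch of the aggregation the new StreamHandler performs.
package main

import (
	"encoding/json"
	"fmt"
)

type GeminiStreamResp struct {
	Candidates []struct {
		Content struct {
			Parts []struct {
				Text string `json:"text"`
			} `json:"parts"`
			Role string `json:"role"`
		} `json:"content"`
		FinishReason string `json:"finishReason"`
		Index        int64  `json:"index"`
	} `json:"candidates"`
}

func main() {
	// Illustrative payload: the upstream body is one JSON array of chunks.
	body := []byte(`[
		{"candidates":[{"content":{"parts":[{"text":"Hello, "}],"role":"model"},"index":0}]},
		{"candidates":[{"content":{"parts":[{"text":"world."}],"role":"model"},"finishReason":"STOP","index":0}]}
	]`)

	var chunks []GeminiStreamResp
	if err := json.Unmarshal(body, &chunks); err != nil {
		panic(err)
	}

	// Concatenate every candidate part into a single response string.
	responseText := ""
	for _, chunk := range chunks {
		for _, cand := range chunk.Candidates {
			for _, part := range cand.Content.Parts {
				responseText += part.Text
			}
		}
	}
	fmt.Println(responseText) // Hello, world.
}
```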