mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-10 18:43:41 +08:00
fix: Refactor relay/channel, upgrade deps, improve request handling and error messages.
* Updated relay/channel/gemini package to use gin-gonic/gin for routing * Added timeouts, environment variable for proxy, and error handling for HTTP clients in relay/util/init.go * Improved error handling, URL path cases, and channel selection logic in middleware/distributor.go * Added Content-Type header, closed request bodies, and added context to requests in relay/channel/common.go * Upgraded various dependencies in go.mod * Modified the GetRequestURL method in relay/channel/gemini/adaptor.go to include a "key" parameter in the URL and set a default version of "v1beta" * Added io and net/http packages in relay/channel/gemini/adaptor.go and relay/channel/gemini/main.go * Added a struct for GeminiStreamResp and modified response handling in relay/channel/gemini/main.go * Imported packages for io and net/http added, gin-gonic/gin package added, and error handling improved in relay/channel/gemini/main.go
This commit is contained in:
@@ -22,10 +22,15 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get request url failed")
|
||||
}
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
|
||||
req, err := http.NewRequestWithContext(c.Request.Context(),
|
||||
c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "new request failed")
|
||||
}
|
||||
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
|
||||
err = a.SetupRequestHeader(c, req, meta)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "setup request header failed")
|
||||
@@ -47,5 +52,6 @@ func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
|
||||
}
|
||||
_ = req.Body.Close()
|
||||
_ = c.Request.Body.Close()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,9 @@ package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/Laisky/errors/v2"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@@ -9,8 +12,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
@@ -21,17 +22,18 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
|
||||
version := helper.AssignOrDefault(meta.APIVersion, "v1")
|
||||
version := helper.AssignOrDefault(meta.APIVersion, "v1beta")
|
||||
action := "generateContent"
|
||||
if meta.IsStream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
|
||||
return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
|
||||
channelhelper.SetupCommonRequestHeader(c, req, meta)
|
||||
req.Header.Set("x-goog-api-key", meta.APIKey)
|
||||
req.URL.Query().Add("key", meta.APIKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
@@ -13,11 +16,6 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
|
||||
@@ -99,8 +97,9 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(context.TODO(),
|
||||
fmt.Sprintf("send %d images to gemini-pro-vision", len(parts)))
|
||||
fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum))
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
@@ -197,73 +196,182 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
|
||||
return &response
|
||||
}
|
||||
|
||||
// [{
|
||||
// "candidates": [
|
||||
// {
|
||||
// "content": {
|
||||
// "parts": [
|
||||
// {
|
||||
// "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
|
||||
// }
|
||||
// ],
|
||||
// "role": "model"
|
||||
// },
|
||||
// "finishReason": "STOP",
|
||||
// "index": 0,
|
||||
// "safetyRatings": [
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_HARASSMENT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// ],
|
||||
// "promptFeedback": {
|
||||
// "safetyRatings": [
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_HARASSMENT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// },
|
||||
// {
|
||||
// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
// "probability": "NEGLIGIBLE"
|
||||
// }
|
||||
// ]
|
||||
// }
|
||||
// }]
|
||||
type GeminiStreamResp struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
Role string `json:"role"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
} `json:"candidates"`
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
|
||||
}
|
||||
|
||||
var respData []GeminiStreamResp
|
||||
if err = json.Unmarshal(respBody, &respData); err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
|
||||
}
|
||||
|
||||
for _, chunk := range respData {
|
||||
for _, cad := range chunk.Candidates {
|
||||
for _, part := range cad.Content.Parts {
|
||||
responseText += part.Text
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = responseText
|
||||
resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: helper.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
})
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
|
||||
}
|
||||
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
|
||||
// dataChan := make(chan string)
|
||||
// stopChan := make(chan bool)
|
||||
// scanner := bufio.NewScanner(resp.Body)
|
||||
// scanner.Split(bufio.ScanLines)
|
||||
// // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
// // if atEOF && len(data) == 0 {
|
||||
// // return 0, nil, nil
|
||||
// // }
|
||||
// // if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
// // return i + 1, data[0:i], nil
|
||||
// // }
|
||||
// // if atEOF {
|
||||
// // return len(data), data, nil
|
||||
// // }
|
||||
// // return 0, nil, nil
|
||||
// // })
|
||||
// go func() {
|
||||
// var content string
|
||||
// for scanner.Scan() {
|
||||
// line := strings.TrimSpace(scanner.Text())
|
||||
// fmt.Printf("> gemini got line: %s\n", line)
|
||||
// content += line
|
||||
// // if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
// // continue
|
||||
// // }
|
||||
|
||||
// // data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
// // data = strings.TrimSuffix(data, "\"")
|
||||
// // dataChan <- data
|
||||
// }
|
||||
|
||||
// dataChan <- content
|
||||
// stopChan <- true
|
||||
// }()
|
||||
// common.SetEventStreamHeaders(c)
|
||||
// c.Stream(func(w io.Writer) bool {
|
||||
// select {
|
||||
// case data := <-dataChan:
|
||||
// // this is used to prevent annoying \ related format bug
|
||||
// data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
// type dummyStruct struct {
|
||||
// Content string `json:"content"`
|
||||
// }
|
||||
// var dummy dummyStruct
|
||||
// err := json.Unmarshal([]byte(data), &dummy)
|
||||
// responseText += dummy.Content
|
||||
// var choice openai.ChatCompletionsStreamResponseChoice
|
||||
// choice.Delta.Content = dummy.Content
|
||||
// response := openai.ChatCompletionsStreamResponse{
|
||||
// Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
|
||||
// Object: "chat.completion.chunk",
|
||||
// Created: helper.GetTimestamp(),
|
||||
// Model: "gemini-pro",
|
||||
// Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
// }
|
||||
// jsonResponse, err := json.Marshal(response)
|
||||
// if err != nil {
|
||||
// logger.SysError("error marshalling stream response: " + err.Error())
|
||||
// return true
|
||||
// }
|
||||
// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
// return true
|
||||
// case <-stopChan:
|
||||
// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
// return false
|
||||
// }
|
||||
// })
|
||||
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user