refactor: use adaptor to do relay & test

JustSong
2024-02-18 00:15:31 +08:00
parent d548a01c59
commit 1aa374ccfb
63 changed files with 1452 additions and 1332 deletions
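This commit routes every relay call through a per-channel adaptor. The adaptor interface itself is outside the excerpted hunks; the sketch below is inferred from the call sites further down (helper.GetAdaptor, adaptor.ConvertRequest, adaptor.DoRequest, adaptor.DoResponse), so the method set and signatures are assumptions, not the committed code.

package channel

import (
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/util"
)

// Adaptor abstracts one upstream channel (OpenAI, Anthropic, Gemini, ...).
// Sketch only: the real interface likely adds methods for URL and header
// setup, model lists, etc.
type Adaptor interface {
	// ConvertRequest maps the incoming OpenAI-style request to the
	// channel's own request shape; the caller JSON-marshals the result.
	ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
	// DoRequest sends the (possibly converted) body upstream.
	DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
	// DoResponse relays the upstream response to the client and returns
	// token usage for quota accounting.
	DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (*model.Usage, *model.ErrorWithStatusCode)
}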


@@ -14,13 +14,14 @@ import (
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
 	"strings"
 )
-func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
 	audioModel := "whisper-1"
 	tokenId := c.GetInt("token_id")


@@ -11,14 +11,14 @@ import (
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"math"
 	"net/http"
 )
-func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
-	textRequest := &openai.GeneralOpenAIRequest{}
+func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
+	textRequest := &relaymodel.GeneralOpenAIRequest{}
 	err := common.UnmarshalBodyReusable(c, textRequest)
 	if err != nil {
 		return nil, err
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOp
 	return textRequest, nil
 }
-func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
+func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 	switch relayMode {
 	case constant.RelayModeChatCompletions:
 		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
@@ -48,7 +48,7 @@ func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) in
 	return 0
 }
-func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
+func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
 	preConsumedTokens := config.PreConsumedQuota
 	if textRequest.MaxTokens != 0 {
 		preConsumedTokens = promptTokens + textRequest.MaxTokens
@@ -56,7 +56,7 @@ func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens
 	return int(float64(preConsumedTokens) * ratio)
 }
-func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
+func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
 	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
 	userQuota, err := model.CacheGetUserQuota(meta.UserId)
@@ -85,7 +85,7 @@ func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIReque
 	return preConsumedQuota, nil
 }
-func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
 	if usage == nil {
 		logger.Error(ctx, "usage is nil, which is unexpected")
 		return
@@ -120,27 +120,3 @@ func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.Relay
 		model.UpdateChannelUsedQuota(meta.ChannelId, quota)
 	}
 }
-func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
-	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
-	if err != nil {
-		return nil, err
-	}
-	SetupRequestHeaders(c, req, meta, isStream)
-	resp, err := util.HTTPClient.Do(req)
-	if err != nil {
-		return nil, err
-	}
-	if resp == nil {
-		return nil, errors.New("resp is nil")
-	}
-	err = req.Body.Close()
-	if err != nil {
-		logger.Warnf(ctx, "close req.Body failed: %+v", err)
-	}
-	err = c.Request.Body.Close()
-	if err != nil {
-		logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
-	}
-	return resp, nil
-}
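The doRequest helper deleted above does not vanish: under the adaptor pattern the same round trip runs behind adaptor.DoRequest. Below is a minimal sketch of a shared helper an adaptor could delegate to, assuming two hypothetical adaptor methods, GetRequestURL and SetupRequestHeader, that do not appear in this diff.

package channel

import (
	"errors"
	"io"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/util"
)

// urlHeaderProvider names the two adaptor capabilities this helper needs;
// both method names are assumptions, not taken from the diff.
type urlHeaderProvider interface {
	GetRequestURL(meta *util.RelayMeta) (string, error)
	SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
}

// DoRequestHelper mirrors the deleted doRequest: build the request, let the
// adaptor supply URL and headers, perform the call, then release both bodies.
func DoRequestHelper(a urlHeaderProvider, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
	fullRequestURL, err := a.GetRequestURL(meta)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return nil, err
	}
	if err := a.SetupRequestHeader(c, req, meta); err != nil {
		return nil, err
	}
	resp, err := util.HTTPClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp == nil {
		return nil, errors.New("resp is nil")
	}
	if req.Body != nil {
		_ = req.Body.Close() // same cleanup the old doRequest performed
	}
	_ = c.Request.Body.Close()
	return resp, nil
}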


@@ -10,6 +10,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
@@ -28,7 +29,7 @@ func isWithinRange(element string, value int) bool {
 	return value >= min && value <= max
 }
-func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
+func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
 	imageModel := "dall-e-2"
 	imageSize := "1024x1024"


@@ -12,19 +12,21 @@ import (
 	"github.com/songquanpeng/one-api/relay/channel/ali"
 	"github.com/songquanpeng/one-api/relay/channel/anthropic"
 	"github.com/songquanpeng/one-api/relay/channel/baidu"
-	"github.com/songquanpeng/one-api/relay/channel/google"
+	"github.com/songquanpeng/one-api/relay/channel/gemini"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/channel/palm"
 	"github.com/songquanpeng/one-api/relay/channel/tencent"
 	"github.com/songquanpeng/one-api/relay/channel/xunfei"
 	"github.com/songquanpeng/one-api/relay/channel/zhipu"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 	"io"
 	"net/http"
 	"strings"
 )
-func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
+func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *model.GeneralOpenAIRequest) (string, error) {
 	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
 	switch meta.APIType {
 	case constant.APITypeOpenAI:
@@ -43,7 +45,7 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
 			requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 			fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
 		}
-	case constant.APITypeClaude:
+	case constant.APITypeAnthropic:
 		fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
 	case constant.APITypeBaidu:
 		switch textRequest.Model {
@@ -92,19 +94,10 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
 	return fullRequestURL, nil
 }
-func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
+func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
 	var requestBody io.Reader
-	if isModelMapped {
-		jsonStr, err := json.Marshal(textRequest)
-		if err != nil {
-			return nil, err
-		}
-		requestBody = bytes.NewBuffer(jsonStr)
-	} else {
-		requestBody = c.Request.Body
-	}
 	switch apiType {
-	case constant.APITypeClaude:
+	case constant.APITypeAnthropic:
 		claudeRequest := anthropic.ConvertRequest(textRequest)
 		jsonStr, err := json.Marshal(claudeRequest)
 		if err != nil {
@@ -127,14 +120,14 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
 		}
 		requestBody = bytes.NewBuffer(jsonData)
 	case constant.APITypePaLM:
-		palmRequest := google.ConvertPaLMRequest(textRequest)
+		palmRequest := palm.ConvertRequest(textRequest)
 		jsonStr, err := json.Marshal(palmRequest)
 		if err != nil {
 			return nil, err
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	case constant.APITypeGemini:
-		geminiChatRequest := google.ConvertGeminiRequest(textRequest)
+		geminiChatRequest := gemini.ConvertRequest(textRequest)
 		jsonStr, err := json.Marshal(geminiChatRequest)
 		if err != nil {
 			return nil, err
@@ -187,19 +180,20 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
 			return nil, err
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	default:
+		if isModelMapped {
+			jsonStr, err := json.Marshal(textRequest)
+			if err != nil {
+				return nil, err
+			}
+			requestBody = bytes.NewBuffer(jsonStr)
+		} else {
+			requestBody = c.Request.Body
+		}
 	}
 	return requestBody, nil
 }
 func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
 	SetupAuthHeaders(c, req, meta, isStream)
 	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 	if isStream && c.Request.Header.Get("Accept") == "" {
 		req.Header.Set("Accept", "text/event-stream")
 	}
 }
 func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
 	apiKey := meta.APIKey
 	switch meta.APIType {
@@ -213,7 +207,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
 				req.Header.Set("X-Title", "One API")
 			}
 		}
-	case constant.APITypeClaude:
+	case constant.APITypeAnthropic:
 		req.Header.Set("x-api-key", apiKey)
 		anthropicVersion := c.Request.Header.Get("anthropic-version")
 		if anthropicVersion == "" {
@@ -242,7 +236,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
 	}
 }
-func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
+func DoResponse(c *gin.Context, textRequest *model.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	var responseText string
 	switch apiType {
 	case constant.APITypeOpenAI:
@@ -251,7 +245,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
 		} else {
 			err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
 		}
-	case constant.APITypeClaude:
+	case constant.APITypeAnthropic:
 		if isStream {
 			err, responseText = anthropic.StreamHandler(c, resp)
 		} else {
@@ -270,15 +264,15 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
 		}
 	case constant.APITypePaLM:
 		if isStream { // PaLM2 API does not support stream
-			err, responseText = google.PaLMStreamHandler(c, resp)
+			err, responseText = palm.StreamHandler(c, resp)
 		} else {
-			err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
+			err, usage = palm.Handler(c, resp, promptTokens, textRequest.Model)
 		}
 	case constant.APITypeGemini:
 		if isStream {
-			err, responseText = google.StreamHandler(c, resp)
+			err, responseText = gemini.StreamHandler(c, resp)
 		} else {
-			err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
+			err, usage = gemini.Handler(c, resp, promptTokens, textRequest.Model)
 		}
 	case constant.APITypeZhipu:
 		if isStream {
@@ -328,7 +322,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
 		return nil, err
 	}
 	if usage == nil && responseText != "" {
-		usage = &openai.Usage{}
+		usage = &model.Usage{}
 		usage.PromptTokens = promptTokens
 		usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
 		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
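The APIType switch in DoResponse above is exactly the dispatch the adaptor layer replaces. Here is a sketch of how helper.GetAdaptor (used by RelayTextHelper below) can collapse it into one lookup, assuming each channel package exports an Adaptor implementation; only a few channels are shown and the type names are assumptions.

package helper

import (
	"github.com/songquanpeng/one-api/relay/channel"
	"github.com/songquanpeng/one-api/relay/channel/anthropic"
	"github.com/songquanpeng/one-api/relay/channel/gemini"
	"github.com/songquanpeng/one-api/relay/channel/openai"
	"github.com/songquanpeng/one-api/relay/channel/palm"
	"github.com/songquanpeng/one-api/relay/constant"
)

// GetAdaptor picks the adaptor for an API type; unknown types yield nil,
// which RelayTextHelper maps to an "invalid_api_type" error.
func GetAdaptor(apiType int) channel.Adaptor {
	switch apiType {
	case constant.APITypeOpenAI:
		return &openai.Adaptor{}
	case constant.APITypeAnthropic:
		return &anthropic.Adaptor{}
	case constant.APITypeGemini:
		return &gemini.Adaptor{}
	case constant.APITypePaLM:
		return &palm.Adaptor{}
	default:
		return nil
	}
}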


@@ -1,18 +1,23 @@
 package controller
 import (
+	"bytes"
+	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
 	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/helper"
+	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
+	"io"
 	"net/http"
 	"strings"
 )
-func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
+func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	ctx := c.Request.Context()
 	meta := util.GetRelayMeta(c)
 	// get & validate textRequest
@@ -21,9 +26,13 @@ func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
 		logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
 		return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
 	}
+	meta.IsStream = textRequest.Stream
 	// map model name
 	var isModelMapped bool
+	meta.OriginModelName = textRequest.Model
 	textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+	meta.ActualModelName = textRequest.Model
 	// get model ratio & group ratio
 	modelRatio := common.GetModelRatio(textRequest.Model)
 	groupRatio := common.GetGroupRatio(meta.Group)
@@ -36,35 +45,50 @@ func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
 		return bizErr
 	}
+	adaptor := helper.GetAdaptor(meta.APIType)
+	if adaptor == nil {
+		return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
+	}
 	// get request body
-	requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
-	if err != nil {
-		return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
+	var requestBody io.Reader
+	if meta.APIType == constant.APITypeOpenAI {
+		// no need to convert request for openai
+		if isModelMapped {
+			jsonStr, err := json.Marshal(textRequest)
+			if err != nil {
+				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
+			}
+			requestBody = bytes.NewBuffer(jsonStr)
+		} else {
+			requestBody = c.Request.Body
+		}
+	} else {
+		convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
+		if err != nil {
+			return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
+		}
+		jsonData, err := json.Marshal(convertedRequest)
+		if err != nil {
+			return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
		}
+		requestBody = bytes.NewBuffer(jsonData)
+	}
 	// do request
-	var resp *http.Response
-	isStream := textRequest.Stream
-	if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
-		fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
-		if err != nil {
-			logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
-			return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
-		}
-		resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
-		if err != nil {
-			logger.Errorf(ctx, "doRequest failed: %s", err.Error())
-			return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
-		}
-		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
-		if resp.StatusCode != http.StatusOK {
-			util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
-			return util.RelayErrorHandler(resp)
-		}
+	resp, err := adaptor.DoRequest(c, meta, requestBody)
+	if err != nil {
+		logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
+		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
+	}
+	meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
+	if resp.StatusCode != http.StatusOK {
+		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+		return util.RelayErrorHandler(resp)
+	}
 	// do response
-	usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
+	usage, respErr := adaptor.DoResponse(c, resp, meta)
 	if respErr != nil {
 		logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
 		util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)