chore: remove helper & util subpackage for relay

This commit is contained in:
JustSong
2024-04-06 01:50:12 +08:00
parent 24ed170e7b
commit f586ae0ad8
19 changed files with 221 additions and 215 deletions

View File

@@ -1,4 +1,4 @@
package helper
package relay
import (
"github.com/songquanpeng/one-api/relay/adaptor"

View File

@@ -9,9 +9,9 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -305,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := util.ImpatientHTTPClient.Do(req)
res, err := client.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}

View File

@@ -4,8 +4,8 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/client"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -39,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}

View File

@@ -10,7 +10,6 @@ import (
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -43,11 +42,11 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
//https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
case channeltype.Minimax:
return minimax.GetRequestURL(meta)
default:
return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}
}

View File

@@ -1,6 +1,11 @@
package openai
import "github.com/songquanpeng/one-api/relay/model"
import (
"fmt"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
@@ -9,3 +14,17 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}

View File

@@ -1,4 +1,4 @@
package util
package client
import (
"github.com/songquanpeng/one-api/common/config"

View File

@@ -17,9 +17,9 @@ import (
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -125,7 +125,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
baseURL = c.GetString("base_url")
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == channeltype.Azure {
apiVersion := azure.GetAPIVersion(c)
if relayMode == relaymode.AudioTranscription {
@@ -162,7 +162,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@@ -215,7 +215,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota

91
relay/controller/error.go Normal file
View File

@@ -0,0 +1,91 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
)
type GeneralErrorResponse struct {
Error model.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) {
ErrorWithStatusCode = &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: model.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}

View File

@@ -12,10 +12,10 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller/validator"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"github.com/songquanpeng/one-api/relay/util"
"math"
"net/http"
)
@@ -32,7 +32,7 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
err = util.ValidateTextRequest(textRequest, relayMode)
err = validator.ValidateTextRequest(textRequest, relayMode)
if err != nil {
return nil, err
}
@@ -193,3 +193,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}

View File

@@ -9,13 +9,12 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -41,7 +40,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
@@ -66,7 +65,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = c.Request.Body
}
adaptor := helper.GetAdaptor(meta.APIType)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}

View File

@@ -6,15 +6,14 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -34,7 +33,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model)
@@ -49,7 +48,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return bizErr
}
adaptor := helper.GetAdaptor(meta.APIType)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
@@ -90,7 +89,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")

View File

@@ -1,4 +1,4 @@
package util
package validator
import (
"errors"

View File

@@ -1,161 +0,0 @@
package util
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channeltype"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
)
func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
type GeneralErrorResponse struct {
Error relaymodel.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: relaymodel.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}

View File

@@ -1,12 +0,0 @@
package util
func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}