mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 13:43:42 +08:00
✨ feat: support other OpenAI APIs (#165)
* ✨ feat: support other OpenAI APIs * 🔖 chore: Update English translation
This commit is contained in:
@@ -78,7 +78,8 @@ func GetProvider(c *gin.Context, modeName string) (provider providersBase.Provid
|
||||
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
ignore := c.GetBool("specific_channel_id_ignore")
|
||||
if channelId > 0 && !ignore {
|
||||
return fetchChannelById(channelId)
|
||||
}
|
||||
|
||||
@@ -206,7 +207,8 @@ func responseCache(c *gin.Context, response string) {
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
ignore := c.GetBool("specific_channel_id_ignore")
|
||||
if channelId > 0 && !ignore {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
@@ -230,3 +232,11 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
|
||||
controller.DisableChannel(channelId, channelName, err.Message, true)
|
||||
}
|
||||
}
|
||||
|
||||
func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -75,15 +75,10 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
|
||||
if apiErr != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
if apiErr.StatusCode == http.StatusTooManyRequests {
|
||||
apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
|
||||
c.JSON(apiErr.StatusCode, gin.H{
|
||||
"error": apiErr.OpenAIError,
|
||||
})
|
||||
|
||||
relayResponseWithErr(c, apiErr)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
85
relay/relay.go
Normal file
85
relay/relay.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers/azure"
|
||||
"one-api/providers/openai"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayOnly(c *gin.Context) {
|
||||
provider, _, fail := GetProvider(c, "")
|
||||
if fail != nil {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, fail.Error())
|
||||
return
|
||||
}
|
||||
|
||||
channel := provider.GetChannel()
|
||||
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取请求的path
|
||||
url := ""
|
||||
path := c.Request.URL.Path
|
||||
openAIProvider, ok := provider.(*openai.OpenAIProvider)
|
||||
if !ok {
|
||||
azureProvider, ok := provider.(*azure.AzureProvider)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type openai")
|
||||
return
|
||||
}
|
||||
url = azureProvider.GetFullRequestURL(path, "")
|
||||
} else {
|
||||
url = openAIProvider.GetFullRequestURL(path, "")
|
||||
}
|
||||
|
||||
headers := c.Request.Header
|
||||
mapHeaders := provider.GetRequestHeaders()
|
||||
// 设置请求头
|
||||
for k, v := range headers {
|
||||
if _, ok := mapHeaders[k]; ok {
|
||||
continue
|
||||
}
|
||||
mapHeaders[k] = strings.Join(v, ", ")
|
||||
}
|
||||
|
||||
requester := provider.GetRequester()
|
||||
req, err := requester.NewRequest(c.Request.Method, url, requester.WithBody(c.Request.Body), requester.WithHeader(mapHeaders))
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
defer req.Body.Close()
|
||||
|
||||
response, errWithCode := requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
relayResponseWithErr(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
errWithCode = responseMultipart(c, response)
|
||||
|
||||
if errWithCode != nil {
|
||||
relayResponseWithErr(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
requestTime := 0
|
||||
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
|
||||
if requestStartTimeValue != nil {
|
||||
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
||||
if ok {
|
||||
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
||||
}
|
||||
}
|
||||
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime)
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user