mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-29 22:56:39 +08:00
86 lines
2.2 KiB
Go
86 lines
2.2 KiB
Go
package relay
|
|
|
|
import (
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/model"
|
|
"one-api/providers/azure"
|
|
"one-api/providers/openai"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func RelayOnly(c *gin.Context) {
|
|
provider, _, fail := GetProvider(c, "")
|
|
if fail != nil {
|
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, fail.Error())
|
|
return
|
|
}
|
|
|
|
channel := provider.GetChannel()
|
|
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure {
|
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
|
|
return
|
|
}
|
|
|
|
// 获取请求的path
|
|
url := ""
|
|
path := c.Request.URL.Path
|
|
openAIProvider, ok := provider.(*openai.OpenAIProvider)
|
|
if !ok {
|
|
azureProvider, ok := provider.(*azure.AzureProvider)
|
|
if !ok {
|
|
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type openai")
|
|
return
|
|
}
|
|
url = azureProvider.GetFullRequestURL(path, "")
|
|
} else {
|
|
url = openAIProvider.GetFullRequestURL(path, "")
|
|
}
|
|
|
|
headers := c.Request.Header
|
|
mapHeaders := provider.GetRequestHeaders()
|
|
// 设置请求头
|
|
for k, v := range headers {
|
|
if _, ok := mapHeaders[k]; ok {
|
|
continue
|
|
}
|
|
mapHeaders[k] = strings.Join(v, ", ")
|
|
}
|
|
|
|
requester := provider.GetRequester()
|
|
req, err := requester.NewRequest(c.Request.Method, url, requester.WithBody(c.Request.Body), requester.WithHeader(mapHeaders))
|
|
if err != nil {
|
|
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
|
return
|
|
}
|
|
|
|
defer req.Body.Close()
|
|
|
|
response, errWithCode := requester.SendRequestRaw(req)
|
|
if errWithCode != nil {
|
|
relayResponseWithErr(c, errWithCode)
|
|
return
|
|
}
|
|
|
|
errWithCode = responseMultipart(c, response)
|
|
|
|
if errWithCode != nil {
|
|
relayResponseWithErr(c, errWithCode)
|
|
return
|
|
}
|
|
|
|
requestTime := 0
|
|
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
|
|
if requestStartTimeValue != nil {
|
|
requestStartTime, ok := requestStartTimeValue.(time.Time)
|
|
if ok {
|
|
requestTime = int(time.Since(requestStartTime).Milliseconds())
|
|
}
|
|
}
|
|
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime)
|
|
|
|
}
|