mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 05:33:42 +08:00
✨ feat: channel support weight (#85)
* ✨ feat: channel support weight * 💄 improve: show version * 💄 improve: Channel add copy operation * 💄 improve: Channel support batch add
This commit is contained in:
@@ -122,7 +122,7 @@ func updateAllChannelsBalance() error {
|
||||
} else {
|
||||
// err is nil & balance <= 0 means quota is used up
|
||||
if balance <= 0 {
|
||||
disableChannel(channel.Id, channel.Name, "余额不足")
|
||||
DisableChannel(channel.Id, channel.Name, "余额不足")
|
||||
}
|
||||
}
|
||||
time.Sleep(common.RequestInterval)
|
||||
|
||||
@@ -140,14 +140,6 @@ func notifyRootUser(subject string, content string) {
|
||||
}
|
||||
}
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
// enable & notify
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
@@ -185,10 +177,10 @@ func testAllChannels(notify bool) error {
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if milliseconds > disableThreshold {
|
||||
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
if isChannelEnabled && ShouldDisableChannel(openaiErr, -1) {
|
||||
DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
|
||||
75
controller/common.go
Normal file
75
controller/common.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ShouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||
if !common.AutomaticDisableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
|
||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// disable & notify
|
||||
func DisableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := types.OpenAIError{
|
||||
Message: "API not implemented",
|
||||
Type: "one_api_error",
|
||||
Param: "",
|
||||
Code: "api_not_implemented",
|
||||
}
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := types.OpenAIError{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
Code: "",
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func ListModels(c *gin.Context) {
|
||||
groupName = user.Group
|
||||
}
|
||||
|
||||
models, err := model.CacheGetGroupModels(groupName)
|
||||
models, err := model.ChannelGroup.GetGroupModels(groupName)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||
return
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayChat(c *gin.Context) {
|
||||
|
||||
var chatRequest types.ChatCompletionRequest
|
||||
if err := common.UnmarshalBodyReusable(c, &chatRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, chatRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
chatRequest.Model = modelName
|
||||
|
||||
chatProvider, ok := provider.(providersBase.ChatInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
if chatRequest.Stream {
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient(c, response)
|
||||
} else {
|
||||
var response *types.ChatCompletionResponse
|
||||
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
}
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayCompletions(c *gin.Context) {
|
||||
|
||||
var completionRequest types.CompletionRequest
|
||||
if err := common.UnmarshalBodyReusable(c, &completionRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, completionRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
completionRequest.Model = modelName
|
||||
|
||||
completionProvider, ok := provider.(providersBase.CompletionInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
if completionRequest.Stream {
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, errWithCode = completionProvider.CreateCompletionStream(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseStreamClient(c, response)
|
||||
} else {
|
||||
var response *types.CompletionResponse
|
||||
response, errWithCode = completionProvider.CreateCompletion(&completionRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
}
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayEmbeddings(c *gin.Context) {
|
||||
|
||||
var embeddingsRequest types.EmbeddingRequest
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
embeddingsRequest.Model = c.Param("model")
|
||||
}
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &embeddingsRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
embeddingsRequest.Model = modelName
|
||||
|
||||
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,79 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayImageEdits(c *gin.Context) {
|
||||
|
||||
var imageEditRequest types.ImageEditRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if imageEditRequest.Prompt == "" {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "field prompt is required")
|
||||
return
|
||||
}
|
||||
|
||||
if imageEditRequest.Model == "" {
|
||||
imageEditRequest.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if imageEditRequest.Size == "" {
|
||||
imageEditRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageEditRequest.Model = modelName
|
||||
|
||||
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayImageGenerations(c *gin.Context) {
|
||||
|
||||
var imageRequest types.ImageRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &imageRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if imageRequest.Model == "" {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
if imageRequest.Quality == "" {
|
||||
imageRequest.Quality = "standard"
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, imageRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageRequest.Model = modelName
|
||||
|
||||
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens, err := common.CountTokenImage(imageRequest)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayImageVariations(c *gin.Context) {
|
||||
|
||||
var imageEditRequest types.ImageEditRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &imageEditRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if imageEditRequest.Model == "" {
|
||||
imageEditRequest.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if imageEditRequest.Size == "" {
|
||||
imageEditRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
imageEditRequest.Model = modelName
|
||||
|
||||
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens, err := common.CountTokenImage(imageEditRequest)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayModerations(c *gin.Context) {
|
||||
|
||||
var moderationRequest types.ModerationRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &moderationRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if moderationRequest.Model == "" {
|
||||
moderationRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, moderationRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
moderationRequest.Model = modelName
|
||||
|
||||
moderationProvider, ok := provider.(providersBase.ModerationInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseJsonClient(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelaySpeech(c *gin.Context) {
|
||||
|
||||
var speechRequest types.SpeechAudioRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &speechRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, speechRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
speechRequest.Model = modelName
|
||||
|
||||
speechProvider, ok := provider.(providersBase.SpeechInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := len(speechRequest.Input)
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseMultipart(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayTranscriptions(c *gin.Context) {
|
||||
|
||||
var audioRequest types.AudioRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
audioRequest.Model = modelName
|
||||
|
||||
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := 0
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseCustom(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayTranslations(c *gin.Context) {
|
||||
|
||||
var audioRequest types.AudioRequest
|
||||
|
||||
if err := common.UnmarshalBodyReusable(c, &audioRequest); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 获取供应商
|
||||
provider, modelName, fail := getProvider(c, audioRequest.Model)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
audioRequest.Model = modelName
|
||||
|
||||
translationProvider, ok := provider.(providersBase.TranslationInterface)
|
||||
if !ok {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取Input Tokens
|
||||
promptTokens := 0
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
provider.SetUsage(usage)
|
||||
|
||||
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
|
||||
if errWithCode != nil {
|
||||
errorHelper(c, errWithCode)
|
||||
return
|
||||
}
|
||||
errWithCode = responseCustom(c, response)
|
||||
|
||||
// 如果报错,则退还配额
|
||||
if errWithCode != nil {
|
||||
quotaInfo.undo(c, errWithCode)
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(c, usage)
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := types.OpenAIError{
|
||||
Message: "API not implemented",
|
||||
Type: "one_api_error",
|
||||
Param: "",
|
||||
Code: "api_not_implemented",
|
||||
}
|
||||
c.JSON(http.StatusNotImplemented, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := types.OpenAIError{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
Code: "",
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"error": err,
|
||||
})
|
||||
}
|
||||
|
||||
func errorHelper(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelName := c.GetString("channel_name")
|
||||
disableChannel(channelId, channelName, err.Message)
|
||||
}
|
||||
}
|
||||
53
controller/relay/base.go
Normal file
53
controller/relay/base.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"one-api/types"
|
||||
|
||||
providersBase "one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayBase struct {
|
||||
c *gin.Context
|
||||
provider providersBase.ProviderInterface
|
||||
originalModel string
|
||||
modelName string
|
||||
}
|
||||
|
||||
type RelayBaseInterface interface {
|
||||
send() (err *types.OpenAIErrorWithStatusCode, done bool)
|
||||
getPromptTokens() (int, error)
|
||||
setRequest() error
|
||||
setProvider(modelName string) error
|
||||
getProvider() providersBase.ProviderInterface
|
||||
getOriginalModel() string
|
||||
getModelName() string
|
||||
getContext() *gin.Context
|
||||
}
|
||||
|
||||
func (r *relayBase) setProvider(modelName string) error {
|
||||
provider, modelName, fail := getProvider(r.c, modelName)
|
||||
if fail != nil {
|
||||
return fail
|
||||
}
|
||||
r.provider = provider
|
||||
r.modelName = modelName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayBase) getContext() *gin.Context {
|
||||
return r.c
|
||||
}
|
||||
|
||||
func (r *relayBase) getProvider() providersBase.ProviderInterface {
|
||||
return r.provider
|
||||
}
|
||||
|
||||
func (r *relayBase) getOriginalModel() string {
|
||||
return r.originalModel
|
||||
}
|
||||
|
||||
func (r *relayBase) getModelName() string {
|
||||
return r.modelName
|
||||
}
|
||||
76
controller/relay/chat.go
Normal file
76
controller/relay/chat.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayChat struct {
|
||||
relayBase
|
||||
chatRequest types.ChatCompletionRequest
|
||||
}
|
||||
|
||||
func NewRelayChat(c *gin.Context) *relayChat {
|
||||
relay := &relayChat{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayChat) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.chatRequest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.chatRequest.MaxTokens < 0 || r.chatRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return errors.New("max_tokens is invalid")
|
||||
}
|
||||
|
||||
r.originalModel = r.chatRequest.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayChat) getPromptTokens() (int, error) {
|
||||
return common.CountTokenMessages(r.chatRequest.Messages, r.modelName), nil
|
||||
}
|
||||
|
||||
func (r *relayChat) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
chatProvider, ok := r.provider.(providersBase.ChatInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.chatRequest.Model = r.modelName
|
||||
|
||||
if r.chatRequest.Stream {
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, err = chatProvider.CreateChatCompletionStream(&r.chatRequest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = responseStreamClient(r.c, response)
|
||||
} else {
|
||||
var response *types.ChatCompletionResponse
|
||||
response, err = chatProvider.CreateChatCompletion(&r.chatRequest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
76
controller/relay/completions.go
Normal file
76
controller/relay/completions.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayCompletions struct {
|
||||
relayBase
|
||||
request types.CompletionRequest
|
||||
}
|
||||
|
||||
func NewRelayCompletions(c *gin.Context) *relayCompletions {
|
||||
relay := &relayCompletions{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayCompletions) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.request.MaxTokens < 0 || r.request.MaxTokens > math.MaxInt32/2 {
|
||||
return errors.New("max_tokens is invalid")
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayCompletions) getPromptTokens() (int, error) {
|
||||
return common.CountTokenInput(r.request.Prompt, r.modelName), nil
|
||||
}
|
||||
|
||||
func (r *relayCompletions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.CompletionInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
if r.request.Stream {
|
||||
var response requester.StreamReaderInterface[string]
|
||||
response, err = provider.CreateCompletionStream(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = responseStreamClient(r.c, response)
|
||||
} else {
|
||||
var response *types.CompletionResponse
|
||||
response, err = provider.CreateCompletion(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
63
controller/relay/embeddings.go
Normal file
63
controller/relay/embeddings.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayEmbeddings struct {
|
||||
relayBase
|
||||
request types.EmbeddingRequest
|
||||
}
|
||||
|
||||
func NewRelayEmbeddings(c *gin.Context) *relayEmbeddings {
|
||||
relay := &relayEmbeddings{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayEmbeddings) setRequest() error {
|
||||
if strings.HasSuffix(r.c.Request.URL.Path, "embeddings") {
|
||||
r.request.Model = r.c.Param("model")
|
||||
}
|
||||
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayEmbeddings) getPromptTokens() (int, error) {
|
||||
return common.CountTokenInput(r.request.Input, r.modelName), nil
|
||||
}
|
||||
|
||||
func (r *relayEmbeddings) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.EmbeddingsInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateEmbeddings(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
71
controller/relay/image-edits.go
Normal file
71
controller/relay/image-edits.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayImageEdits struct {
|
||||
relayBase
|
||||
request types.ImageEditRequest
|
||||
}
|
||||
|
||||
func NewRelayImageEdits(c *gin.Context) *relayImageEdits {
|
||||
relay := &relayImageEdits{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayImageEdits) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.request.Prompt == "" {
|
||||
return errors.New("field prompt is required")
|
||||
}
|
||||
|
||||
if r.request.Model == "" {
|
||||
r.request.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if r.request.Size == "" {
|
||||
r.request.Size = "1024x1024"
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayImageEdits) getPromptTokens() (int, error) {
|
||||
return common.CountTokenImage(r.request)
|
||||
}
|
||||
|
||||
func (r *relayImageEdits) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.ImageEditsInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateImageEdits(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
74
controller/relay/image-generations.go
Normal file
74
controller/relay/image-generations.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayImageGenerations struct {
|
||||
relayBase
|
||||
request types.ImageRequest
|
||||
}
|
||||
|
||||
func NewRelayImageGenerations(c *gin.Context) *relayImageGenerations {
|
||||
relay := &relayImageGenerations{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayImageGenerations) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.request.Model == "" {
|
||||
r.request.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if r.request.N == 0 {
|
||||
r.request.N = 1
|
||||
}
|
||||
|
||||
if r.request.Size == "" {
|
||||
r.request.Size = "1024x1024"
|
||||
}
|
||||
|
||||
if r.request.Quality == "" {
|
||||
r.request.Quality = "standard"
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayImageGenerations) getPromptTokens() (int, error) {
|
||||
return common.CountTokenImage(r.request)
|
||||
}
|
||||
|
||||
func (r *relayImageGenerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.ImageGenerationsInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateImageGenerations(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
66
controller/relay/image-variationsy.go
Normal file
66
controller/relay/image-variationsy.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayImageVariations struct {
|
||||
relayBase
|
||||
request types.ImageEditRequest
|
||||
}
|
||||
|
||||
func NewRelayImageVariations(c *gin.Context) *relayImageVariations {
|
||||
relay := &relayImageVariations{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayImageVariations) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.request.Model == "" {
|
||||
r.request.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if r.request.Size == "" {
|
||||
r.request.Size = "1024x1024"
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayImageVariations) getPromptTokens() (int, error) {
|
||||
return common.CountTokenImage(r.request)
|
||||
}
|
||||
|
||||
func (r *relayImageVariations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.ImageVariationsInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateImageVariations(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
106
controller/relay/main.go
Normal file
106
controller/relay/main.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relay := Path2Relay(c, c.Request.URL.Path)
|
||||
if relay == nil {
|
||||
common.AbortWithMessage(c, http.StatusNotFound, "Not Found")
|
||||
return
|
||||
}
|
||||
|
||||
if err := relay.setRequest(); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
apiErr, done := RelayHandler(relay)
|
||||
if apiErr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
channel := relay.getProvider().GetChannel()
|
||||
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
|
||||
|
||||
retryTimes := common.RetryTimes
|
||||
if done || !shouldRetry(c, apiErr.StatusCode) {
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error happen, status code is %d, won't retry in this case", apiErr.StatusCode))
|
||||
retryTimes = 0
|
||||
}
|
||||
|
||||
for i := retryTimes; i > 0; i-- {
|
||||
// 冻结通道
|
||||
model.ChannelGroup.Cooldowns(channel.Id)
|
||||
if err := relay.setProvider(relay.getOriginalModel()); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
channel = relay.getProvider().GetChannel()
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("using channel #%d(%s) to retry (remain times %d)", channel.Id, channel.Name, i))
|
||||
apiErr, done = RelayHandler(relay)
|
||||
if apiErr == nil {
|
||||
return
|
||||
}
|
||||
go processChannelRelayError(c.Request.Context(), channel.Id, channel.Name, apiErr)
|
||||
if done || !shouldRetry(c, apiErr.StatusCode) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if apiErr != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
if apiErr.StatusCode == http.StatusTooManyRequests {
|
||||
apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
|
||||
c.JSON(apiErr.StatusCode, gin.H{
|
||||
"error": apiErr.OpenAIError,
|
||||
})
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func RelayHandler(relay RelayBaseInterface) (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
promptTokens, tonkeErr := relay.getPromptTokens()
|
||||
if tonkeErr != nil {
|
||||
err = common.ErrorWrapper(tonkeErr, "token_error", http.StatusBadRequest)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
usage := &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
}
|
||||
|
||||
relay.getProvider().SetUsage(usage)
|
||||
|
||||
var quotaInfo *QuotaInfo
|
||||
quotaInfo, err = generateQuotaInfo(relay.getContext(), relay.getModelName(), promptTokens)
|
||||
if err != nil {
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
err, done = relay.send()
|
||||
|
||||
if err != nil {
|
||||
quotaInfo.undo(relay.getContext())
|
||||
return
|
||||
}
|
||||
|
||||
quotaInfo.consume(relay.getContext(), usage)
|
||||
return
|
||||
}
|
||||
62
controller/relay/moderations.go
Normal file
62
controller/relay/moderations.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayModerations struct {
|
||||
relayBase
|
||||
request types.ModerationRequest
|
||||
}
|
||||
|
||||
func NewRelayModerations(c *gin.Context) *relayModerations {
|
||||
relay := &relayModerations{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayModerations) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if r.request.Model == "" {
|
||||
r.request.Model = "text-moderation-stable"
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayModerations) getPromptTokens() (int, error) {
|
||||
return common.CountTokenInput(r.request.Input, r.modelName), nil
|
||||
}
|
||||
|
||||
func (r *relayModerations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.ModerationInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateModeration(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseJsonClient(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -144,7 +144,7 @@ func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName stri
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
func (q *QuotaInfo) undo(c *gin.Context) {
|
||||
tokenId := c.GetInt("token_id")
|
||||
if q.HandelStatus {
|
||||
go func(ctx context.Context) {
|
||||
@@ -155,7 +155,6 @@ func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatu
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
errorHelper(c, errWithCode)
|
||||
}
|
||||
|
||||
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
|
||||
58
controller/relay/speech.go
Normal file
58
controller/relay/speech.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relaySpeech struct {
|
||||
relayBase
|
||||
request types.SpeechAudioRequest
|
||||
}
|
||||
|
||||
func NewRelaySpeech(c *gin.Context) *relaySpeech {
|
||||
relay := &relaySpeech{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relaySpeech) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relaySpeech) getPromptTokens() (int, error) {
|
||||
return len(r.request.Input), nil
|
||||
}
|
||||
|
||||
func (r *relaySpeech) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.SpeechInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateSpeech(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseMultipart(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
58
controller/relay/transcriptions.go
Normal file
58
controller/relay/transcriptions.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayTranscriptions struct {
|
||||
relayBase
|
||||
request types.AudioRequest
|
||||
}
|
||||
|
||||
func NewRelayTranscriptions(c *gin.Context) *relayTranscriptions {
|
||||
relay := &relayTranscriptions{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayTranscriptions) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayTranscriptions) getPromptTokens() (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *relayTranscriptions) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.TranscriptionsInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateTranscriptions(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseCustom(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
58
controller/relay/translations.go
Normal file
58
controller/relay/translations.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type relayTranslations struct {
|
||||
relayBase
|
||||
request types.AudioRequest
|
||||
}
|
||||
|
||||
func NewRelayTranslations(c *gin.Context) *relayTranslations {
|
||||
relay := &relayTranslations{}
|
||||
relay.c = c
|
||||
return relay
|
||||
}
|
||||
|
||||
func (r *relayTranslations) setRequest() error {
|
||||
if err := common.UnmarshalBodyReusable(r.c, &r.request); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.originalModel = r.request.Model
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *relayTranslations) getPromptTokens() (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (r *relayTranslations) send() (err *types.OpenAIErrorWithStatusCode, done bool) {
|
||||
provider, ok := r.provider.(providersBase.TranslationInterface)
|
||||
if !ok {
|
||||
err = common.StringErrorWrapper("channel not implemented", "channel_error", http.StatusServiceUnavailable)
|
||||
done = true
|
||||
return
|
||||
}
|
||||
|
||||
r.request.Model = r.modelName
|
||||
|
||||
response, err := provider.CreateTranslation(&r.request)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = responseCustom(r.c, response)
|
||||
|
||||
if err != nil {
|
||||
done = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package controller
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -8,127 +9,98 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/controller"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
providersBase "one-api/providers/base"
|
||||
"one-api/types"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail bool) {
|
||||
func Path2Relay(c *gin.Context, path string) RelayBaseInterface {
|
||||
if strings.HasPrefix(path, "/v1/chat/completions") {
|
||||
return NewRelayChat(c)
|
||||
} else if strings.HasPrefix(path, "/v1/completions") {
|
||||
return NewRelayCompletions(c)
|
||||
} else if strings.HasPrefix(path, "/v1/embeddings") {
|
||||
return NewRelayEmbeddings(c)
|
||||
} else if strings.HasPrefix(path, "/v1/moderations") {
|
||||
return NewRelayModerations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/generations") {
|
||||
return NewRelayImageGenerations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/edits") {
|
||||
return NewRelayImageEdits(c)
|
||||
} else if strings.HasPrefix(path, "/v1/images/variations") {
|
||||
return NewRelayImageVariations(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/speech") {
|
||||
return NewRelaySpeech(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
|
||||
return NewRelayTranscriptions(c)
|
||||
} else if strings.HasPrefix(path, "/v1/audio/translations") {
|
||||
return NewRelayTranslations(c)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getProvider(c *gin.Context, modeName string) (provider providersBase.ProviderInterface, newModelName string, fail error) {
|
||||
channel, fail := fetchChannel(c, modeName)
|
||||
if fail {
|
||||
if fail != nil {
|
||||
return
|
||||
}
|
||||
c.Set("channel_id", channel.Id)
|
||||
|
||||
provider = providers.GetProvider(channel, c)
|
||||
if provider == nil {
|
||||
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
|
||||
fail = true
|
||||
fail = errors.New("channel not found")
|
||||
return
|
||||
}
|
||||
provider.SetOriginalModel(modeName)
|
||||
|
||||
newModelName, err := provider.ModelMappingHandler(modeName)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
fail = true
|
||||
newModelName, fail = provider.ModelMappingHandler(modeName)
|
||||
if fail != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetValidFieldName(err error, obj interface{}) string {
|
||||
getObj := reflect.TypeOf(obj)
|
||||
if errs, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, e := range errs {
|
||||
if f, exist := getObj.Elem().FieldByName(e.Field()); exist {
|
||||
return f.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
|
||||
channelId := c.GetInt("channelId")
|
||||
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
channel, fail = fetchChannelById(c, channelId)
|
||||
if fail {
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
channel, fail = fetchChannelByModel(c, modelName)
|
||||
if fail {
|
||||
return
|
||||
return fetchChannelById(channelId)
|
||||
}
|
||||
|
||||
c.Set("channel_id", channel.Id)
|
||||
|
||||
return
|
||||
return fetchChannelByModel(c, modelName)
|
||||
}
|
||||
|
||||
func fetchChannelById(c *gin.Context, channelId int) (*model.Channel, bool) {
|
||||
func fetchChannelById(channelId int) (*model.Channel, error) {
|
||||
channel, err := model.GetChannelById(channelId, true)
|
||||
if err != nil {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return nil, true
|
||||
return nil, errors.New("无效的渠道 Id")
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
common.AbortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return nil, true
|
||||
return nil, errors.New("该渠道已被禁用")
|
||||
}
|
||||
|
||||
return channel, false
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool) {
|
||||
func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, error) {
|
||||
group := c.GetString("group")
|
||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, modelName)
|
||||
channel, err := model.ChannelGroup.Next(group, modelName)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", group, modelName)
|
||||
if channel != nil {
|
||||
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
common.AbortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
return nil, true
|
||||
return nil, errors.New(message)
|
||||
}
|
||||
|
||||
return channel, false
|
||||
}
|
||||
|
||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||
if !common.AutomaticDisableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
|
||||
@@ -201,3 +173,30 @@ func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func shouldRetry(c *gin.Context, statusCode int) bool {
|
||||
channelId := c.GetInt("specific_channel_id")
|
||||
if channelId > 0 {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if statusCode/100 == 5 {
|
||||
return true
|
||||
}
|
||||
if statusCode == http.StatusBadRequest {
|
||||
return false
|
||||
}
|
||||
if statusCode/100 == 2 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *types.OpenAIErrorWithStatusCode) {
|
||||
common.LogError(ctx, fmt.Sprintf("relay error (channel #%d(%s)): %s", channelId, channelName, err.Message))
|
||||
if controller.ShouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||
controller.DisableChannel(channelId, channelName, err.Message)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user