mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
- Improve error handling across multiple middleware and adapter components, ensuring consistent error response formats in JSON. - Enhance the functionality of request conversion functions by including context parameters and robust error wrapping. - Introduce new features related to reasoning content in the messaging model, providing better customization and explanations in the documentation.
126 lines
3.9 KiB
Go
126 lines
3.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
|
|
gutils "github.com/Laisky/go-utils/v5"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/pkg/errors"
|
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
|
"github.com/songquanpeng/one-api/common/logger"
|
|
"github.com/songquanpeng/one-api/model"
|
|
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
|
)
|
|
|
|
type ModelRequest struct {
|
|
Model string `json:"model" form:"model"`
|
|
}
|
|
|
|
func Distribute() func(c *gin.Context) {
|
|
return func(c *gin.Context) {
|
|
ctx := c.Request.Context()
|
|
userId := c.GetInt(ctxkey.Id)
|
|
userGroup, _ := model.CacheGetUserGroup(userId)
|
|
c.Set(ctxkey.Group, userGroup)
|
|
var requestModel string
|
|
var channel *model.Channel
|
|
channelId, ok := c.Get(ctxkey.SpecificChannelId)
|
|
if ok {
|
|
id, err := strconv.Atoi(channelId.(string))
|
|
if err != nil {
|
|
abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
|
|
return
|
|
}
|
|
channel, err = model.GetChannelById(id, true)
|
|
if err != nil {
|
|
abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
|
|
return
|
|
}
|
|
if channel.Status != model.ChannelStatusEnabled {
|
|
abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
|
|
return
|
|
}
|
|
} else {
|
|
requestModel = c.GetString(ctxkey.RequestModel)
|
|
var err error
|
|
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
|
if err != nil {
|
|
message := fmt.Sprintf("No available channels for Model %s under Group %s", requestModel, userGroup)
|
|
if channel != nil {
|
|
logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
|
|
message = "Database consistency has been broken, please contact the administrator"
|
|
}
|
|
abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
|
|
return
|
|
}
|
|
}
|
|
logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
|
|
SetupContextForSelectedChannel(c, channel, requestModel)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
|
// one channel could relates to multiple groups,
|
|
// and each groud has individual ratio,
|
|
// set minimal group ratio as channel_ratio
|
|
var minimalRatio float64 = -1
|
|
for _, grp := range strings.Split(channel.Group, ",") {
|
|
v := ratio.GetGroupRatio(grp)
|
|
if minimalRatio < 0 || v < minimalRatio {
|
|
minimalRatio = v
|
|
}
|
|
}
|
|
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
|
|
c.Set(ctxkey.ChannelRatio, minimalRatio)
|
|
c.Set(ctxkey.ChannelModel, channel)
|
|
|
|
// generate an unique cost id for each request
|
|
if _, ok := c.Get(ctxkey.RequestId); !ok {
|
|
c.Set(ctxkey.RequestId, gutils.UUID7())
|
|
}
|
|
|
|
c.Set(ctxkey.Channel, channel.Type)
|
|
c.Set(ctxkey.ChannelId, channel.Id)
|
|
c.Set(ctxkey.ChannelName, channel.Name)
|
|
c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type"))
|
|
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
|
|
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
|
|
}
|
|
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
|
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
|
cfg, _ := channel.LoadConfig()
|
|
// this is for backward compatibility
|
|
if channel.Other != nil {
|
|
switch channel.Type {
|
|
case channeltype.Azure:
|
|
if cfg.APIVersion == "" {
|
|
cfg.APIVersion = *channel.Other
|
|
}
|
|
case channeltype.Xunfei:
|
|
if cfg.APIVersion == "" {
|
|
cfg.APIVersion = *channel.Other
|
|
}
|
|
case channeltype.Gemini:
|
|
if cfg.APIVersion == "" {
|
|
cfg.APIVersion = *channel.Other
|
|
}
|
|
case channeltype.AIProxyLibrary:
|
|
if cfg.LibraryID == "" {
|
|
cfg.LibraryID = *channel.Other
|
|
}
|
|
case channeltype.Ali:
|
|
if cfg.Plugin == "" {
|
|
cfg.Plugin = *channel.Other
|
|
}
|
|
}
|
|
}
|
|
c.Set(ctxkey.Config, cfg)
|
|
}
|