mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-12-25 01:05:56 +08:00
chore: do not hardcode context key
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/blacklist"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
@@ -120,20 +121,20 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set("request_model", requestModel)
|
||||
c.Set(ctxkey.RequestModel, requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
c.Set("available_models", *token.Models)
|
||||
c.Set(ctxkey.AvailableModels, *token.Models)
|
||||
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
c.Set(ctxkey.Id, token.UserId)
|
||||
c.Set(ctxkey.TokenId, token.Id)
|
||||
c.Set(ctxkey.TokenName, token.Name)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("specific_channel_id", parts[1])
|
||||
c.Set(ctxkey.SpecificChannelId, parts[1])
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
return
|
||||
|
||||
@@ -17,12 +17,12 @@ type ModelRequest struct {
|
||||
|
||||
func Distribute() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
userId := c.GetInt(ctxkey.Id)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
c.Set(ctxkey.Group, userGroup)
|
||||
var requestModel string
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
channelId, ok := c.Get(ctxkey.SpecificChannelId)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -39,7 +39,7 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
requestModel = c.GetString("request_model")
|
||||
requestModel = c.GetString(ctxkey.RequestModel)
|
||||
var err error
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||
if err != nil {
|
||||
@@ -58,13 +58,13 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set(ctxkey.Channel, channel.Type)
|
||||
c.Set(ctxkey.ChannelId, channel.Id)
|
||||
c.Set(ctxkey.ChannelName, channel.Name)
|
||||
c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
|
||||
c.Set(ctxkey.OriginalModel, modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
|
||||
// this is for backward compatibility
|
||||
switch channel.Type {
|
||||
case channeltype.Azure:
|
||||
|
||||
Reference in New Issue
Block a user