mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	chore: do not hardcode context key
This commit is contained in:
		@@ -1,7 +1,21 @@
 | 
			
		||||
package ctxkey
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RequestModel     = "request_model"
 | 
			
		||||
	ConvertedRequest = "converted_request"
 | 
			
		||||
	OriginalModel    = "original_model"
 | 
			
		||||
	Id                = "id"
 | 
			
		||||
	Username          = "username"
 | 
			
		||||
	Role              = "role"
 | 
			
		||||
	Status            = "status"
 | 
			
		||||
	Channel           = "channel"
 | 
			
		||||
	ChannelId         = "channel_id"
 | 
			
		||||
	SpecificChannelId = "specific_channel_id"
 | 
			
		||||
	RequestModel      = "request_model"
 | 
			
		||||
	ConvertedRequest  = "converted_request"
 | 
			
		||||
	OriginalModel     = "original_model"
 | 
			
		||||
	Group             = "group"
 | 
			
		||||
	ModelMapping      = "model_mapping"
 | 
			
		||||
	ChannelName       = "channel_name"
 | 
			
		||||
	TokenId           = "token_id"
 | 
			
		||||
	TokenName         = "token_name"
 | 
			
		||||
	BaseURL           = "base_url"
 | 
			
		||||
	AvailableModels   = "available_models"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -136,7 +137,7 @@ func WeChatBind(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		Id: id,
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	relaymodel "github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
)
 | 
			
		||||
@@ -14,13 +15,13 @@ func GetSubscription(c *gin.Context) {
 | 
			
		||||
	var token *model.Token
 | 
			
		||||
	var expiredTime int64
 | 
			
		||||
	if config.DisplayTokenStatEnabled {
 | 
			
		||||
		tokenId := c.GetInt("token_id")
 | 
			
		||||
		tokenId := c.GetInt(ctxkey.TokenId)
 | 
			
		||||
		token, err = model.GetTokenById(tokenId)
 | 
			
		||||
		expiredTime = token.ExpiredTime
 | 
			
		||||
		remainQuota = token.RemainQuota
 | 
			
		||||
		usedQuota = token.UsedQuota
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		remainQuota, err = model.GetUserQuota(userId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			usedQuota, err = model.GetUserUsedQuota(userId)
 | 
			
		||||
@@ -64,11 +65,11 @@ func GetUsage(c *gin.Context) {
 | 
			
		||||
	var err error
 | 
			
		||||
	var token *model.Token
 | 
			
		||||
	if config.DisplayTokenStatEnabled {
 | 
			
		||||
		tokenId := c.GetInt("token_id")
 | 
			
		||||
		tokenId := c.GetInt(ctxkey.TokenId)
 | 
			
		||||
		token, err = model.GetTokenById(tokenId)
 | 
			
		||||
		quota = token.UsedQuota
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		quota, err = model.GetUserUsedQuota(userId)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/message"
 | 
			
		||||
	"github.com/songquanpeng/one-api/middleware"
 | 
			
		||||
@@ -54,8 +55,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 | 
			
		||||
	}
 | 
			
		||||
	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
			
		||||
	c.Request.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	c.Set("channel", channel.Type)
 | 
			
		||||
	c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
	c.Set(ctxkey.Channel, channel.Type)
 | 
			
		||||
	c.Set(ctxkey.BaseURL, channel.GetBaseURL())
 | 
			
		||||
	middleware.SetupContextForSelectedChannel(c, channel, "")
 | 
			
		||||
	meta := meta.GetByContext(c)
 | 
			
		||||
	apiType := channeltype.ToAPIType(channel.Type)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -41,7 +42,7 @@ func GetUserLogs(c *gin.Context) {
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
	}
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	logType, _ := strconv.Atoi(c.Query("type"))
 | 
			
		||||
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
@@ -83,7 +84,7 @@ func SearchAllLogs(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func SearchUserLogs(c *gin.Context) {
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	logs, err := model.SearchUserLogs(userId, keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -122,7 +123,7 @@ func GetLogsStat(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLogsSelfStat(c *gin.Context) {
 | 
			
		||||
	username := c.GetString("username")
 | 
			
		||||
	username := c.GetString(ctxkey.Username)
 | 
			
		||||
	logType, _ := strconv.Atoi(c.Query("type"))
 | 
			
		||||
	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
 | 
			
		||||
	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	relay "github.com/songquanpeng/one-api/relay"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
@@ -131,10 +132,10 @@ func ListAllModels(c *gin.Context) {
 | 
			
		||||
func ListModels(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	var availableModels []string
 | 
			
		||||
	if c.GetString("available_models") != "" {
 | 
			
		||||
		availableModels = strings.Split(c.GetString("available_models"), ",")
 | 
			
		||||
	if c.GetString(ctxkey.AvailableModels) != "" {
 | 
			
		||||
		availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
 | 
			
		||||
	}
 | 
			
		||||
@@ -186,7 +187,7 @@ func RetrieveModel(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func GetUserAvailableModels(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	userGroup, err := model.CacheGetUserGroup(id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -109,7 +110,7 @@ func AddRedemption(c *gin.Context) {
 | 
			
		||||
	for i := 0; i < redemption.Count; i++ {
 | 
			
		||||
		key := random.GetUUID()
 | 
			
		||||
		cleanRedemption := model.Redemption{
 | 
			
		||||
			UserId:      c.GetInt("id"),
 | 
			
		||||
			UserId:      c.GetInt(ctxkey.Id),
 | 
			
		||||
			Name:        redemption.Name,
 | 
			
		||||
			Key:         key,
 | 
			
		||||
			CreatedTime: helper.GetTimestamp(),
 | 
			
		||||
 
 | 
			
		||||
@@ -46,15 +46,15 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		requestBody, _ := common.GetRequestBody(c)
 | 
			
		||||
		logger.Debugf(ctx, "request body: %s", string(requestBody))
 | 
			
		||||
	}
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
	bizErr := relayHelper(c, relayMode)
 | 
			
		||||
	if bizErr == nil {
 | 
			
		||||
		monitor.Emit(channelId, true)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	lastFailedChannelId := channelId
 | 
			
		||||
	channelName := c.GetString("channel_name")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
	group := c.GetString(ctxkey.Group)
 | 
			
		||||
	originalModel := c.GetString(ctxkey.OriginalModel)
 | 
			
		||||
	go processChannelRelayError(ctx, channelId, channelName, bizErr)
 | 
			
		||||
	requestId := c.GetString(logger.RequestIdKey)
 | 
			
		||||
@@ -80,9 +80,9 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		if bizErr == nil {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
		lastFailedChannelId = channelId
 | 
			
		||||
		channelName := c.GetString("channel_name")
 | 
			
		||||
		channelName := c.GetString(ctxkey.ChannelName)
 | 
			
		||||
		go processChannelRelayError(ctx, channelId, channelName, bizErr)
 | 
			
		||||
	}
 | 
			
		||||
	if bizErr != nil {
 | 
			
		||||
@@ -97,7 +97,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func shouldRetry(c *gin.Context, statusCode int) bool {
 | 
			
		||||
	if _, ok := c.Get("specific_channel_id"); ok {
 | 
			
		||||
	if _, ok := c.Get(ctxkey.SpecificChannelId); ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if statusCode == http.StatusTooManyRequests {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
@@ -13,7 +14,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAllTokens(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	p, _ := strconv.Atoi(c.Query("p"))
 | 
			
		||||
	if p < 0 {
 | 
			
		||||
		p = 0
 | 
			
		||||
@@ -38,7 +39,7 @@ func GetAllTokens(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SearchTokens(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	keyword := c.Query("keyword")
 | 
			
		||||
	tokens, err := model.SearchUserTokens(userId, keyword)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -58,7 +59,7 @@ func SearchTokens(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func GetToken(c *gin.Context) {
 | 
			
		||||
	id, err := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
@@ -83,8 +84,8 @@ func GetToken(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetTokenStatus(c *gin.Context) {
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	tokenId := c.GetInt(ctxkey.TokenId)
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	token, err := model.GetTokenByIds(tokenId, userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -139,7 +140,7 @@ func AddToken(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cleanToken := model.Token{
 | 
			
		||||
		UserId:         c.GetInt("id"),
 | 
			
		||||
		UserId:         c.GetInt(ctxkey.Id),
 | 
			
		||||
		Name:           token.Name,
 | 
			
		||||
		Key:            random.GenerateKey(),
 | 
			
		||||
		CreatedTime:    helper.GetTimestamp(),
 | 
			
		||||
@@ -168,7 +169,7 @@ func AddToken(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func DeleteToken(c *gin.Context) {
 | 
			
		||||
	id, _ := strconv.Atoi(c.Param("id"))
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	err := model.DeleteTokenById(id, userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -185,7 +186,7 @@ func DeleteToken(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateToken(c *gin.Context) {
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	statusOnly := c.Query("status_only")
 | 
			
		||||
	token := model.Token{}
 | 
			
		||||
	err := c.ShouldBindJSON(&token)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -238,7 +239,7 @@ func GetUser(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	myRole := c.GetInt(ctxkey.Role)
 | 
			
		||||
	if myRole <= user.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
@@ -255,7 +256,7 @@ func GetUser(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetUserDashboard(c *gin.Context) {
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
 | 
			
		||||
	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
 | 
			
		||||
@@ -278,7 +279,7 @@ func GetUserDashboard(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GenerateAccessToken(c *gin.Context) {
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	user, err := model.GetUserById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -314,7 +315,7 @@ func GenerateAccessToken(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAffCode(c *gin.Context) {
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	user, err := model.GetUserById(id, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -342,7 +343,7 @@ func GetAffCode(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetSelf(c *gin.Context) {
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	id := c.GetInt(ctxkey.Id)
 | 
			
		||||
	user, err := model.GetUserById(id, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -387,7 +388,7 @@ func UpdateUser(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	myRole := c.GetInt(ctxkey.Role)
 | 
			
		||||
	if myRole <= originUser.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
@@ -445,7 +446,7 @@ func UpdateSelf(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	cleanUser := model.User{
 | 
			
		||||
		Id:          c.GetInt("id"),
 | 
			
		||||
		Id:          c.GetInt(ctxkey.Id),
 | 
			
		||||
		Username:    user.Username,
 | 
			
		||||
		Password:    user.Password,
 | 
			
		||||
		DisplayName: user.DisplayName,
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ import (
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/blacklist"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"net/http"
 | 
			
		||||
@@ -120,20 +121,20 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			abortWithMessage(c, http.StatusBadRequest, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("request_model", requestModel)
 | 
			
		||||
		c.Set(ctxkey.RequestModel, requestModel)
 | 
			
		||||
		if token.Models != nil && *token.Models != "" {
 | 
			
		||||
			c.Set("available_models", *token.Models)
 | 
			
		||||
			c.Set(ctxkey.AvailableModels, *token.Models)
 | 
			
		||||
			if requestModel != "" && !isModelInList(requestModel, *token.Models) {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		c.Set("id", token.UserId)
 | 
			
		||||
		c.Set("token_id", token.Id)
 | 
			
		||||
		c.Set("token_name", token.Name)
 | 
			
		||||
		c.Set(ctxkey.Id, token.UserId)
 | 
			
		||||
		c.Set(ctxkey.TokenId, token.Id)
 | 
			
		||||
		c.Set(ctxkey.TokenName, token.Name)
 | 
			
		||||
		if len(parts) > 1 {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set("specific_channel_id", parts[1])
 | 
			
		||||
				c.Set(ctxkey.SpecificChannelId, parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
 | 
			
		||||
				return
 | 
			
		||||
 
 | 
			
		||||
@@ -17,12 +17,12 @@ type ModelRequest struct {
 | 
			
		||||
 | 
			
		||||
func Distribute() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		c.Set("group", userGroup)
 | 
			
		||||
		c.Set(ctxkey.Group, userGroup)
 | 
			
		||||
		var requestModel string
 | 
			
		||||
		var channel *model.Channel
 | 
			
		||||
		channelId, ok := c.Get("specific_channel_id")
 | 
			
		||||
		channelId, ok := c.Get(ctxkey.SpecificChannelId)
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -39,7 +39,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
			requestModel = c.GetString("request_model")
 | 
			
		||||
			requestModel = c.GetString(ctxkey.RequestModel)
 | 
			
		||||
			var err error
 | 
			
		||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -58,13 +58,13 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
 | 
			
		||||
	c.Set("channel", channel.Type)
 | 
			
		||||
	c.Set("channel_id", channel.Id)
 | 
			
		||||
	c.Set("channel_name", channel.Name)
 | 
			
		||||
	c.Set("model_mapping", channel.GetModelMapping())
 | 
			
		||||
	c.Set(ctxkey.Channel, channel.Type)
 | 
			
		||||
	c.Set(ctxkey.ChannelId, channel.Id)
 | 
			
		||||
	c.Set(ctxkey.ChannelName, channel.Name)
 | 
			
		||||
	c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
 | 
			
		||||
	c.Set(ctxkey.OriginalModel, modelName) // for retry
 | 
			
		||||
	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
	c.Set("base_url", channel.GetBaseURL())
 | 
			
		||||
	c.Set(ctxkey.BaseURL, channel.GetBaseURL())
 | 
			
		||||
	// this is for backward compatibility
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case channeltype.Azure:
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
 | 
			
		||||
@@ -29,12 +30,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	audioModel := "whisper-1"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
	tokenId := c.GetInt(ctxkey.TokenId)
 | 
			
		||||
	channelType := c.GetInt(ctxkey.Channel)
 | 
			
		||||
	channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	group := c.GetString(ctxkey.Group)
 | 
			
		||||
	tokenName := c.GetString(ctxkey.TokenName)
 | 
			
		||||
 | 
			
		||||
	var ttsRequest openai.TextToSpeechRequest
 | 
			
		||||
	if relayMode == relaymode.AudioSpeech {
 | 
			
		||||
@@ -107,7 +108,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	modelMapping := c.GetString(ctxkey.ModelMapping)
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
@@ -121,8 +122,8 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
 | 
			
		||||
	baseURL := channeltype.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	if c.GetString(ctxkey.BaseURL) != "" {
 | 
			
		||||
		baseURL = c.GetString(ctxkey.BaseURL)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
 | 
			
		||||
 
 | 
			
		||||
@@ -7,6 +7,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay"
 | 
			
		||||
@@ -119,11 +120,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
			logger.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		if quota != 0 {
 | 
			
		||||
			tokenName := c.GetString("token_name")
 | 
			
		||||
			tokenName := c.GetString(ctxkey.TokenName)
 | 
			
		||||
			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelId := c.GetInt(ctxkey.ChannelId)
 | 
			
		||||
			model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
		}
 | 
			
		||||
	}(c.Request.Context())
 | 
			
		||||
 
 | 
			
		||||
@@ -33,14 +33,14 @@ type Meta struct {
 | 
			
		||||
func GetByContext(c *gin.Context) *Meta {
 | 
			
		||||
	meta := Meta{
 | 
			
		||||
		Mode:           relaymode.GetByPath(c.Request.URL.Path),
 | 
			
		||||
		ChannelType:    c.GetInt("channel"),
 | 
			
		||||
		ChannelId:      c.GetInt("channel_id"),
 | 
			
		||||
		TokenId:        c.GetInt("token_id"),
 | 
			
		||||
		TokenName:      c.GetString("token_name"),
 | 
			
		||||
		UserId:         c.GetInt("id"),
 | 
			
		||||
		Group:          c.GetString("group"),
 | 
			
		||||
		ModelMapping:   c.GetStringMapString("model_mapping"),
 | 
			
		||||
		BaseURL:        c.GetString("base_url"),
 | 
			
		||||
		ChannelType:    c.GetInt(ctxkey.Channel),
 | 
			
		||||
		ChannelId:      c.GetInt(ctxkey.ChannelId),
 | 
			
		||||
		TokenId:        c.GetInt(ctxkey.TokenId),
 | 
			
		||||
		TokenName:      c.GetString(ctxkey.TokenName),
 | 
			
		||||
		UserId:         c.GetInt(ctxkey.Id),
 | 
			
		||||
		Group:          c.GetString(ctxkey.Group),
 | 
			
		||||
		ModelMapping:   c.GetStringMapString(ctxkey.ModelMapping),
 | 
			
		||||
		BaseURL:        c.GetString(ctxkey.BaseURL),
 | 
			
		||||
		APIVersion:     c.GetString(ctxkey.ConfigAPIVersion),
 | 
			
		||||
		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
 | 
			
		||||
		Config:         nil,
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user