mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-12 19:33:41 +08:00
fix: models api return models in deactivate channels
- Enhance logging functionality by adding context support and improving debugging options. - Standardize function naming conventions across middleware to ensure consistency. - Optimize data retrieval and handling in the model controller, including caching and error management. - Simplify the bug report template to streamline the issue reporting process.
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/blacklist"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func authHelper(c *gin.Context, minRole int) {
|
||||
@@ -98,34 +99,34 @@ func TokenAuth() func(c *gin.Context) {
|
||||
key = parts[0]
|
||||
token, err := model.ValidateUserToken(key)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
AbortWithError(c, http.StatusUnauthorized, err)
|
||||
return
|
||||
}
|
||||
if token.Subnet != nil && *token.Subnet != "" {
|
||||
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||
AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
AbortWithError(c, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
AbortWithError(c, http.StatusForbidden, errors.New("用户已被封禁"))
|
||||
return
|
||||
}
|
||||
requestModel, err := getRequestModel(c)
|
||||
if err != nil && shouldCheckModel(c) {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
AbortWithError(c, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
c.Set(ctxkey.RequestModel, requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
c.Set(ctxkey.AvailableModels, *token.Models)
|
||||
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌无权使用模型: %s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -136,7 +137,7 @@ func TokenAuth() func(c *gin.Context) {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set(ctxkey.SpecificChannelId, parts[1])
|
||||
} else {
|
||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
||||
AbortWithError(c, http.StatusForbidden, errors.New("普通用户不支持指定渠道"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
@@ -29,16 +29,16 @@ func Distribute() func(c *gin.Context) {
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id"))
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id"))
|
||||
return
|
||||
}
|
||||
if channel.Status != model.ChannelStatusEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
AbortWithError(c, http.StatusForbidden, errors.New("该渠道已被禁用"))
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -51,7 +51,7 @@ func Distribute() func(c *gin.Context) {
|
||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
}
|
||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
||||
AbortWithError(c, http.StatusServiceUnavailable, errors.New(message))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,11 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
@@ -20,6 +21,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
logger.Error(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
// AbortWithError aborts the request with an error message
|
||||
func AbortWithError(c *gin.Context, statusCode int, err error) {
|
||||
logger.Errorf(c, "server abort: %+v", err)
|
||||
c.JSON(statusCode, gin.H{
|
||||
"error": gin.H{
|
||||
"message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)),
|
||||
"type": "one_api_error",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
|
||||
Reference in New Issue
Block a user