mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-21 02:56:38 +08:00
fix: models api return models in deactivate channels
- Enhance logging functionality by adding context support and improving debugging options. - Standardize function naming conventions across middleware to ensure consistency. - Optimize data retrieval and handling in the model controller, including caching and error management. - Simplify the bug report template to streamline the issue reporting process.
This commit is contained in:
parent
8df4a2670b
commit
54635ca4fe
21
.github/ISSUE_TEMPLATE/bug_report.md
vendored
21
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@ -1,20 +1,20 @@
|
|||||||
---
|
---
|
||||||
name: 报告问题
|
name: 报告问题
|
||||||
about: 使用简练详细的语言描述你遇到的问题
|
about: 使用简练详细的语言描述你遇到的问题
|
||||||
title: ''
|
title: ""
|
||||||
labels: bug
|
labels: bug
|
||||||
assignees: ''
|
assignees: ""
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**例行检查**
|
**例行检查**
|
||||||
|
|
||||||
[//]: # (方框内删除已有的空格,填 x 号)
|
[//]: # "方框内删除已有的空格,填 x 号"
|
||||||
+ [ ] 我已确认目前没有类似 issue
|
|
||||||
+ [ ] 我已确认我已升级到最新版本
|
- [ ] 我已确认目前没有类似 issue
|
||||||
+ [ ] 我已完整查看过项目 README,尤其是常见问题部分
|
- [ ] 我已确认我已升级到最新版本
|
||||||
+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
- [ ] 我已完整查看过项目 README,尤其是常见问题部分
|
||||||
+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
- [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈
|
||||||
|
- [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭**
|
||||||
|
|
||||||
**问题描述**
|
**问题描述**
|
||||||
|
|
||||||
@ -23,4 +23,5 @@ assignees: ''
|
|||||||
**预期结果**
|
**预期结果**
|
||||||
|
|
||||||
**相关截图**
|
**相关截图**
|
||||||
如果没有的话,请删除此节。
|
|
||||||
|
如果没有的话,请删除此节。
|
||||||
|
@ -2,8 +2,12 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
|
"github.com/songquanpeng/one-api/middleware"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
relay "github.com/songquanpeng/one-api/relay"
|
relay "github.com/songquanpeng/one-api/relay"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
@ -11,8 +15,6 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/models/list
|
// https://platform.openai.com/docs/api-reference/models/list
|
||||||
@ -33,16 +35,20 @@ type OpenAIModelPermission struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OpenAIModels struct {
|
type OpenAIModels struct {
|
||||||
Id string `json:"id"`
|
// Id model's name
|
||||||
Object string `json:"object"`
|
//
|
||||||
Created int `json:"created"`
|
// BUG: Different channels may have the same model name
|
||||||
|
Id string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int `json:"created"`
|
||||||
|
// OwnedBy is the channel's adaptor name
|
||||||
OwnedBy string `json:"owned_by"`
|
OwnedBy string `json:"owned_by"`
|
||||||
Permission []OpenAIModelPermission `json:"permission"`
|
Permission []OpenAIModelPermission `json:"permission"`
|
||||||
Root string `json:"root"`
|
Root string `json:"root"`
|
||||||
Parent *string `json:"parent"`
|
Parent *string `json:"parent"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var models []OpenAIModels
|
var allModels []OpenAIModels
|
||||||
var modelsMap map[string]OpenAIModels
|
var modelsMap map[string]OpenAIModels
|
||||||
var channelId2Models map[int][]string
|
var channelId2Models map[int][]string
|
||||||
|
|
||||||
@ -71,7 +77,7 @@ func init() {
|
|||||||
channelName := adaptor.GetChannelName()
|
channelName := adaptor.GetChannelName()
|
||||||
modelNames := adaptor.GetModelList()
|
modelNames := adaptor.GetModelList()
|
||||||
for _, modelName := range modelNames {
|
for _, modelName := range modelNames {
|
||||||
models = append(models, OpenAIModels{
|
allModels = append(allModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
@ -88,7 +94,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
|
||||||
for _, modelName := range channelModelList {
|
for _, modelName := range channelModelList {
|
||||||
models = append(models, OpenAIModels{
|
allModels = append(allModels, OpenAIModels{
|
||||||
Id: modelName,
|
Id: modelName,
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1626777600,
|
Created: 1626777600,
|
||||||
@ -100,7 +106,7 @@ func init() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
modelsMap = make(map[string]OpenAIModels)
|
modelsMap = make(map[string]OpenAIModels)
|
||||||
for _, model := range models {
|
for _, model := range allModels {
|
||||||
modelsMap[model.Id] = model
|
modelsMap[model.Id] = model
|
||||||
}
|
}
|
||||||
channelId2Models = make(map[int][]string)
|
channelId2Models = make(map[int][]string)
|
||||||
@ -125,49 +131,56 @@ func DashboardListModels(c *gin.Context) {
|
|||||||
func ListAllModels(c *gin.Context) {
|
func ListAllModels(c *gin.Context) {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": models,
|
"data": allModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ListModels(c *gin.Context) {
|
func ListModels(c *gin.Context) {
|
||||||
ctx := c.Request.Context()
|
userId := c.GetInt(ctxkey.Id)
|
||||||
var availableModels []string
|
userGroup, err := model.CacheGetUserGroup(userId)
|
||||||
if c.GetString(ctxkey.AvailableModels) != "" {
|
if err != nil {
|
||||||
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
|
middleware.AbortWithError(c, http.StatusBadRequest, err)
|
||||||
} else {
|
return
|
||||||
userId := c.GetInt(ctxkey.Id)
|
|
||||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
|
||||||
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
|
|
||||||
}
|
}
|
||||||
modelSet := make(map[string]bool)
|
|
||||||
for _, availableModel := range availableModels {
|
// Get available models with their channel names
|
||||||
modelSet[availableModel] = true
|
availableAbilities, err := model.CacheGetGroupModelsV2(c.Request.Context(), userGroup)
|
||||||
|
if err != nil {
|
||||||
|
middleware.AbortWithError(c, http.StatusBadRequest, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a map for quick lookup of enabled model+channel combinations
|
||||||
|
// Only store the exact model:channel combinations from abilities
|
||||||
|
abilityMap := make(map[string]bool)
|
||||||
|
for _, ability := range availableAbilities {
|
||||||
|
// Store as "modelName:channelName" for exact matching
|
||||||
|
adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType))
|
||||||
|
key := ability.Model + ":" + adaptor.GetChannelName()
|
||||||
|
abilityMap[key] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter models that match user's abilities with EXACT model+channel matches
|
||||||
availableOpenAIModels := make([]OpenAIModels, 0)
|
availableOpenAIModels := make([]OpenAIModels, 0)
|
||||||
for _, model := range models {
|
|
||||||
if _, ok := modelSet[model.Id]; ok {
|
// Only include models that have a matching model+channel combination
|
||||||
modelSet[model.Id] = false
|
for _, model := range allModels {
|
||||||
|
key := model.Id + ":" + model.OwnedBy
|
||||||
|
if abilityMap[key] {
|
||||||
availableOpenAIModels = append(availableOpenAIModels, model)
|
availableOpenAIModels = append(availableOpenAIModels, model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for modelName, ok := range modelSet {
|
|
||||||
if ok {
|
// Sort models alphabetically for consistent presentation
|
||||||
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
|
sort.Slice(availableOpenAIModels, func(i, j int) bool {
|
||||||
Id: modelName,
|
return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id
|
||||||
Object: "model",
|
})
|
||||||
Created: 1626777600,
|
|
||||||
OwnedBy: "custom",
|
|
||||||
Root: modelName,
|
|
||||||
Parent: nil,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"object": "list",
|
"object": "list",
|
||||||
"data": availableOpenAIModels,
|
"data": availableOpenAIModels,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func RetrieveModel(c *gin.Context) {
|
func RetrieveModel(c *gin.Context) {
|
||||||
modelId := c.Param("model")
|
modelId := c.Param("model")
|
||||||
if model, ok := modelsMap[modelId]; ok {
|
if model, ok := modelsMap[modelId]; ok {
|
||||||
@ -196,7 +209,8 @@ func GetUserAvailableModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
models, err := model.CacheGetGroupModels(ctx, userGroup)
|
|
||||||
|
models, err := model.CacheGetGroupModelsV2(ctx, userGroup)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@ -204,10 +218,20 @@ func GetUserAvailableModels(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var modelNames []string
|
||||||
|
modelsMap := map[string]bool{}
|
||||||
|
for _, model := range models {
|
||||||
|
modelsMap[model.Model] = true
|
||||||
|
}
|
||||||
|
for modelName := range modelsMap {
|
||||||
|
modelNames = append(modelNames, modelName)
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
"data": models,
|
"data": modelNames,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -1,15 +1,16 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
"github.com/gin-contrib/sessions"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common/blacklist"
|
"github.com/songquanpeng/one-api/common/blacklist"
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
"github.com/songquanpeng/one-api/common/network"
|
"github.com/songquanpeng/one-api/common/network"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func authHelper(c *gin.Context, minRole int) {
|
func authHelper(c *gin.Context, minRole int) {
|
||||||
@ -98,34 +99,34 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
AbortWithError(c, http.StatusUnauthorized, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if token.Subnet != nil && *token.Subnet != "" {
|
if token.Subnet != nil && *token.Subnet != "" {
|
||||||
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
||||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
AbortWithError(c, http.StatusInternalServerError, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
|
if !userEnabled || blacklist.IsUserBanned(token.UserId) {
|
||||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
AbortWithError(c, http.StatusForbidden, errors.New("用户已被封禁"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
requestModel, err := getRequestModel(c)
|
requestModel, err := getRequestModel(c)
|
||||||
if err != nil && shouldCheckModel(c) {
|
if err != nil && shouldCheckModel(c) {
|
||||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
AbortWithError(c, http.StatusBadRequest, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set(ctxkey.RequestModel, requestModel)
|
c.Set(ctxkey.RequestModel, requestModel)
|
||||||
if token.Models != nil && *token.Models != "" {
|
if token.Models != nil && *token.Models != "" {
|
||||||
c.Set(ctxkey.AvailableModels, *token.Models)
|
c.Set(ctxkey.AvailableModels, *token.Models)
|
||||||
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌无权使用模型: %s", requestModel))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,7 +137,7 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
if model.IsAdmin(token.UserId) {
|
if model.IsAdmin(token.UserId) {
|
||||||
c.Set(ctxkey.SpecificChannelId, parts[1])
|
c.Set(ctxkey.SpecificChannelId, parts[1])
|
||||||
} else {
|
} else {
|
||||||
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
|
AbortWithError(c, http.StatusForbidden, errors.New("普通用户不支持指定渠道"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"github.com/songquanpeng/one-api/model"
|
"github.com/songquanpeng/one-api/model"
|
||||||
@ -29,16 +29,16 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != model.ChannelStatusEnabled {
|
if channel.Status != model.ChannelStatusEnabled {
|
||||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
AbortWithError(c, http.StatusForbidden, errors.New("该渠道已被禁用"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -51,7 +51,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||||
message = "数据库一致性已被破坏,请联系管理员"
|
message = "数据库一致性已被破坏,请联系管理员"
|
||||||
}
|
}
|
||||||
abortWithMessage(c, http.StatusServiceUnavailable, message)
|
AbortWithError(c, http.StatusServiceUnavailable, errors.New(message))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2,11 +2,12 @@ package middleware
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/helper"
|
"github.com/songquanpeng/one-api/common/helper"
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||||
@ -20,6 +21,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
|||||||
logger.Error(c.Request.Context(), message)
|
logger.Error(c.Request.Context(), message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AbortWithError aborts the request with an error message
|
||||||
|
func AbortWithError(c *gin.Context, statusCode int, err error) {
|
||||||
|
logger.Errorf(c, "server abort: %+v", err)
|
||||||
|
c.JSON(statusCode, gin.H{
|
||||||
|
"error": gin.H{
|
||||||
|
"message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)),
|
||||||
|
"type": "one_api_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
func getRequestModel(c *gin.Context) (string, error) {
|
func getRequestModel(c *gin.Context) (string, error) {
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
|
@ -5,10 +5,10 @@ import (
|
|||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/utils"
|
"github.com/songquanpeng/one-api/common/utils"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Ability struct {
|
type Ability struct {
|
||||||
@ -110,3 +110,34 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) {
|
|||||||
sort.Strings(models)
|
sort.Strings(models)
|
||||||
return models, err
|
return models, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type EnabledAbility struct {
|
||||||
|
Model string `json:"model" gorm:"model"`
|
||||||
|
ChannelType int `json:"channel_type" gorm:"channel_type"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetGroupModelsV2 returns all enabled models for this group with their channel names.
|
||||||
|
func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) {
|
||||||
|
// prepare query based on database type
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
|
// query with JOIN to get model and channel name in a single query
|
||||||
|
var models []EnabledAbility
|
||||||
|
query := DB.Model(&Ability{}).
|
||||||
|
Select("abilities.model AS model, channels.type AS channel_type").
|
||||||
|
Joins("JOIN channels ON abilities.channel_id = channels.id").
|
||||||
|
Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group).
|
||||||
|
Order("abilities.priority DESC")
|
||||||
|
|
||||||
|
err := query.Find(&models).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get group models")
|
||||||
|
}
|
||||||
|
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
@ -3,18 +3,19 @@ package model
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/songquanpeng/one-api/common"
|
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
|
||||||
"github.com/songquanpeng/one-api/common/logger"
|
|
||||||
"github.com/songquanpeng/one-api/common/random"
|
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/common"
|
||||||
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
|
"github.com/songquanpeng/one-api/common/logger"
|
||||||
|
"github.com/songquanpeng/one-api/common/random"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -148,7 +149,10 @@ func CacheIsUserEnabled(userId int) (bool, error) {
|
|||||||
return userEnabled, err
|
return userEnabled, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
// CacheGetGroupModels returns models of a group
|
||||||
|
//
|
||||||
|
// Deprecated: use CacheGetGroupModelsV2 instead
|
||||||
|
func CacheGetGroupModels(ctx context.Context, group string) (models []string, err error) {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
return GetGroupModels(ctx, group)
|
return GetGroupModels(ctx, group)
|
||||||
}
|
}
|
||||||
@ -156,7 +160,7 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return strings.Split(modelsStr, ","), nil
|
return strings.Split(modelsStr, ","), nil
|
||||||
}
|
}
|
||||||
models, err := GetGroupModels(ctx, group)
|
models, err = GetGroupModels(ctx, group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -167,6 +171,42 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
|
|||||||
return models, nil
|
return models, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CacheGetGroupModelsV2 is a version of CacheGetGroupModels that returns EnabledAbility instead of string
|
||||||
|
func CacheGetGroupModelsV2(ctx context.Context, group string) (models []EnabledAbility, err error) {
|
||||||
|
if !common.RedisEnabled {
|
||||||
|
return GetGroupModelsV2(ctx, group)
|
||||||
|
}
|
||||||
|
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models_v2:%s", group))
|
||||||
|
if err != nil {
|
||||||
|
logger.Warnf(ctx, "Redis get group models error: %+v", err)
|
||||||
|
} else {
|
||||||
|
if err = json.Unmarshal([]byte(modelsStr), &models); err != nil {
|
||||||
|
logger.Warnf(ctx, "Redis get group models error: %+v", err)
|
||||||
|
} else {
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
models, err = GetGroupModelsV2(ctx, group)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "get group models")
|
||||||
|
}
|
||||||
|
|
||||||
|
cachePayload, err := json.Marshal(models)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("Redis set group models error: " + err.Error())
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), string(cachePayload),
|
||||||
|
time.Duration(GroupModelsCacheSeconds)*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
logger.SysError("Redis set group models error: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|
||||||
var group2model2channels map[string]map[string][]*Channel
|
var group2model2channels map[string]map[string][]*Channel
|
||||||
var channelSyncLock sync.RWMutex
|
var channelSyncLock sync.RWMutex
|
||||||
|
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
"github.com/songquanpeng/one-api/common/env"
|
"github.com/songquanpeng/one-api/common/env"
|
||||||
@ -13,9 +18,6 @@ import (
|
|||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/driver/sqlite"
|
"gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var DB *gorm.DB
|
var DB *gorm.DB
|
||||||
@ -116,6 +118,11 @@ func InitDB() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.DebugSQLEnabled {
|
||||||
|
logger.Debug(context.TODO(), "debug sql enabled")
|
||||||
|
DB = DB.Debug()
|
||||||
|
}
|
||||||
|
|
||||||
sqlDB := setDBConns(DB)
|
sqlDB := setDBConns(DB)
|
||||||
|
|
||||||
if !config.IsMasterNode {
|
if !config.IsMasterNode {
|
||||||
@ -201,10 +208,6 @@ func migrateLOGDB() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func setDBConns(db *gorm.DB) *sql.DB {
|
func setDBConns(db *gorm.DB) *sql.DB {
|
||||||
if config.DebugSQLEnabled {
|
|
||||||
db = db.Debug()
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlDB, err := db.DB()
|
sqlDB, err := db.DB()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.FatalLog("failed to connect database: " + err.Error())
|
logger.FatalLog("failed to connect database: " + err.Error())
|
||||||
|
Loading…
Reference in New Issue
Block a user