fix: models api return models in deactivate channels

- Enhance logging functionality by adding context support and improving debugging options.
- Standardize function naming conventions across middleware to ensure consistency.
- Optimize data retrieval and handling in the model controller, including caching and error management.
- Simplify the bug report template to streamline the issue reporting process.
This commit is contained in:
Laisky.Cai 2025-02-26 11:22:03 +00:00
parent f5d4ff05dc
commit 5905a7f295
8 changed files with 119 additions and 61 deletions

View File

@ -12,9 +12,8 @@ assignees: ""
- [ ] I have confirmed there are no similar issues - [ ] I have confirmed there are no similar issues
- [ ] I have confirmed I am using the latest version - [ ] I have confirmed I am using the latest version
- [ ] I have thoroughly read the project README, especially the FAQ section - [ ] I have thoroughly read the project README
- [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback - [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback
- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed**
## Issue Description ## Issue Description

View File

@ -3,10 +3,11 @@ package controller
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "sort"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
relay "github.com/songquanpeng/one-api/relay" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
@ -34,16 +35,20 @@ type OpenAIModelPermission struct {
} }
type OpenAIModels struct { type OpenAIModels struct {
Id string `json:"id"` // Id model's name
Object string `json:"object"` //
Created int `json:"created"` // BUG: Different channels may have the same model name
Id string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
// OwnedBy is the channel's adaptor name
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
Permission []OpenAIModelPermission `json:"permission"` Permission []OpenAIModelPermission `json:"permission"`
Root string `json:"root"` Root string `json:"root"`
Parent *string `json:"parent"` Parent *string `json:"parent"`
} }
var models []OpenAIModels var allModels []OpenAIModels
var modelsMap map[string]OpenAIModels var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string var channelId2Models map[int][]string
@ -76,7 +81,7 @@ func init() {
channelName := adaptor.GetChannelName() channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList() modelNames := adaptor.GetModelList()
for _, modelName := range modelNames { for _, modelName := range modelNames {
models = append(models, OpenAIModels{ allModels = append(allModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@ -93,7 +98,7 @@ func init() {
} }
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList { for _, modelName := range channelModelList {
models = append(models, OpenAIModels{ allModels = append(allModels, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@ -105,7 +110,7 @@ func init() {
} }
} }
modelsMap = make(map[string]OpenAIModels) modelsMap = make(map[string]OpenAIModels)
for _, model := range models { for _, model := range allModels {
modelsMap[model.Id] = model modelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string) channelId2Models = make(map[int][]string)
@ -134,49 +139,56 @@ func DashboardListModels(c *gin.Context) {
func ListAllModels(c *gin.Context) { func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"object": "list", "object": "list",
"data": models, "data": allModels,
}) })
} }
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
ctx := c.Request.Context() userId := c.GetInt(ctxkey.Id)
var availableModels []string userGroup, err := model.CacheGetUserGroup(userId)
if c.GetString(ctxkey.AvailableModels) != "" { if err != nil {
availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") middleware.AbortWithError(c, http.StatusBadRequest, err)
} else { return
userId := c.GetInt(ctxkey.Id)
userGroup, _ := model.CacheGetUserGroup(userId)
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
} }
modelSet := make(map[string]bool)
for _, availableModel := range availableModels { // Get available models with their channel names
modelSet[availableModel] = true availableAbilities, err := model.GetGroupModelsV2(c.Request.Context(), userGroup)
if err != nil {
middleware.AbortWithError(c, http.StatusBadRequest, err)
return
} }
// Create a map for quick lookup of enabled model+channel combinations
// Only store the exact model:channel combinations from abilities
abilityMap := make(map[string]bool)
for _, ability := range availableAbilities {
// Store as "modelName:channelName" for exact matching
adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType))
key := ability.Model + ":" + adaptor.GetChannelName()
abilityMap[key] = true
}
// Filter models that match user's abilities with EXACT model+channel matches
availableOpenAIModels := make([]OpenAIModels, 0) availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range models {
if _, ok := modelSet[model.Id]; ok { // Only include models that have a matching model+channel combination
modelSet[model.Id] = false for _, model := range allModels {
key := model.Id + ":" + model.OwnedBy
if abilityMap[key] {
availableOpenAIModels = append(availableOpenAIModels, model) availableOpenAIModels = append(availableOpenAIModels, model)
} }
} }
for modelName, ok := range modelSet {
if ok { // Sort models alphabetically for consistent presentation
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ sort.Slice(availableOpenAIModels, func(i, j int) bool {
Id: modelName, return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id
Object: "model", })
Created: 1626777600,
OwnedBy: "custom",
Root: modelName,
Parent: nil,
})
}
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"object": "list", "object": "list",
"data": availableOpenAIModels, "data": availableOpenAIModels,
}) })
} }
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelId := c.Param("model") modelId := c.Param("model")
if model, ok := modelsMap[modelId]; ok { if model, ok := modelsMap[modelId]; ok {

View File

@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if err != nil { if err != nil {
abortWithError(c, http.StatusUnauthorized, err) AbortWithError(c, http.StatusUnauthorized, err)
return return
} }
if token.Subnet != nil && *token.Subnet != "" { if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP())) AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
return return
} }
} }
userEnabled, err := model.CacheIsUserEnabled(token.UserId) userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil { if err != nil {
abortWithError(c, http.StatusInternalServerError, err) AbortWithError(c, http.StatusInternalServerError, err)
return return
} }
if !userEnabled || blacklist.IsUserBanned(token.UserId) { if !userEnabled || blacklist.IsUserBanned(token.UserId) {
abortWithError(c, http.StatusForbidden, errors.New("User has been banned")) AbortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
return return
} }
requestModel, err := getRequestModel(c) requestModel, err := getRequestModel(c)
if err != nil && shouldCheckModel(c) { if err != nil && shouldCheckModel(c) {
abortWithError(c, http.StatusBadRequest, err) AbortWithError(c, http.StatusBadRequest, err)
return return
} }
c.Set(ctxkey.RequestModel, requestModel) c.Set(ctxkey.RequestModel, requestModel)
if token.Models != nil && *token.Models != "" { if token.Models != nil && *token.Models != "" {
c.Set(ctxkey.AvailableModels, *token.Models) c.Set(ctxkey.AvailableModels, *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) { if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel)) AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
return return
} }
} }
@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set(ctxkey.SpecificChannelId, parts[1]) c.Set(ctxkey.SpecificChannelId, parts[1])
} else { } else {
abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels")) AbortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
return return
} }
} }

View File

@ -32,16 +32,16 @@ func Distribute() func(c *gin.Context) {
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
return return
} }
channel, err = model.GetChannelById(id, true) channel, err = model.GetChannelById(id, true)
if err != nil { if err != nil {
abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
return return
} }
if channel.Status != model.ChannelStatusEnabled { if channel.Status != model.ChannelStatusEnabled {
abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled")) AbortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
return return
} }
} else { } else {
@ -54,7 +54,7 @@ func Distribute() func(c *gin.Context) {
logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id)) logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
message = "Database consistency has been broken, please contact the administrator" message = "Database consistency has been broken, please contact the administrator"
} }
abortWithError(c, http.StatusServiceUnavailable, errors.New(message)) AbortWithError(c, http.StatusServiceUnavailable, errors.New(message))
return return
} }
} }

View File

@ -23,7 +23,8 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
logger.Error(c.Request.Context(), message) logger.Error(c.Request.Context(), message)
} }
func abortWithError(c *gin.Context, statusCode int, err error) { // AbortWithError aborts the request with an error message
func AbortWithError(c *gin.Context, statusCode int, err error) {
logger := gmw.GetLogger(c) logger := gmw.GetLogger(c)
logger.Error("server abort", zap.Error(err)) logger.Error("server abort", zap.Error(err))
c.JSON(statusCode, gin.H{ c.JSON(statusCode, gin.H{

View File

@ -4,11 +4,13 @@ import (
"context" "context"
"sort" "sort"
"strings" "strings"
"time"
"gorm.io/gorm" gutils "github.com/Laisky/go-utils/v5"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/utils" "github.com/songquanpeng/one-api/common/utils"
"gorm.io/gorm"
) )
type Ability struct { type Ability struct {
@ -110,3 +112,44 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) {
sort.Strings(models) sort.Strings(models)
return models, err return models, err
} }
var getGroupModelsV2Cache = gutils.NewExpCache[[]EnabledAbility](context.Background(), time.Second*10)
type EnabledAbility struct {
Model string `json:"model" gorm:"model"`
ChannelType int `json:"channel_type" gorm:"channel_type"`
}
// GetGroupModelsV2 returns all enabled models for this group with their channel names.
func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) {
// get from cache first
if models, ok := getGroupModelsV2Cache.Load(group); ok {
return models, nil
}
// prepare query based on database type
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
// query with JOIN to get model and channel name in a single query
var models []EnabledAbility
query := DB.Model(&Ability{}).
Select("abilities.model AS model, channels.type AS channel_type").
Joins("JOIN channels ON abilities.channel_id = channels.id").
Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group).
Order("abilities.priority DESC")
err := query.Find(&models).Error
if err != nil {
return nil, errors.Wrap(err, "get group models")
}
// store in cache
getGroupModelsV2Cache.Store(group, models)
return models, nil
}

View File

@ -4,17 +4,18 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"math/rand" "math/rand"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
) )
var ( var (

View File

@ -1,6 +1,7 @@
package model package model
import ( import (
"context"
"database/sql" "database/sql"
"fmt" "fmt"
"os" "os"
@ -118,6 +119,11 @@ func InitDB() {
return return
} }
if config.DebugSQLEnabled {
logger.Debug(context.TODO(), "debug sql enabled")
DB = DB.Debug()
}
sqlDB := setDBConns(DB) sqlDB := setDBConns(DB)
if !config.IsMasterNode { if !config.IsMasterNode {
@ -203,10 +209,6 @@ func migrateLOGDB() error {
} }
func setDBConns(db *gorm.DB) *sql.DB { func setDBConns(db *gorm.DB) *sql.DB {
if config.DebugSQLEnabled {
db = db.Debug()
}
sqlDB, err := db.DB() sqlDB, err := db.DB()
if err != nil { if err != nil {
logger.FatalLog("failed to connect database: " + err.Error()) logger.FatalLog("failed to connect database: " + err.Error())