Mirror of https://github.com/songquanpeng/one-api.git, synced 2025-09-18 09:36:37 +08:00
fix: models API returning models from deactivated channels

- Enhance logging functionality by adding context support and improving debugging options.
- Standardize function naming conventions across middleware to ensure consistency.
- Optimize data retrieval and handling in the model controller, including caching and exact model-to-channel matching (sketched below).
- Simplify the bug report template to streamline the issue reporting process.
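The third bullet is the heart of the fix: `/v1/models` previously listed every model name available to the user's group, even when the only channel providing that model had been disabled. A minimal, self-contained sketch of the new filtering idea follows; the model and channel names are hypothetical, and the real code uses `OpenAIModels`, `GetGroupModelsV2`, and adaptor lookups as shown in the diff below.

```go
package main

import "fmt"

// modelInfo is a simplified stand-in for the controller's OpenAIModels struct.
type modelInfo struct {
	ID      string
	OwnedBy string // adaptor (channel) name
}

func main() {
	// Hypothetical catalogue: the same model name can be offered by different channels.
	allModels := []modelInfo{
		{ID: "gpt-4o", OwnedBy: "openai"},
		{ID: "gpt-4o", OwnedBy: "azure"},
		{ID: "claude-3-5-sonnet", OwnedBy: "anthropic"},
	}

	// Enabled abilities for the user's group, keyed "model:channel".
	// Suppose the azure channel is disabled, so it contributes no ability.
	enabled := map[string]bool{
		"gpt-4o:openai":               true,
		"claude-3-5-sonnet:anthropic": true,
	}

	visible := make([]modelInfo, 0, len(allModels))
	for _, m := range allModels {
		if enabled[m.ID+":"+m.OwnedBy] {
			visible = append(visible, m)
		}
	}

	// Prints only the entries backed by enabled channels; the azure copy of
	// gpt-4o is filtered out.
	fmt.Println(visible)
}
```

The commit's actual implementation builds the same kind of key from `ability.Model` and the adaptor's channel name, then filters `allModels` against it.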
This commit is contained in:
  parent f5d4ff05dc
  commit 5905a7f295

.github/ISSUE_TEMPLATE/bug_report.md (vendored): 3 changes
@@ -12,9 +12,8 @@ assignees: ""

 - [ ] I have confirmed there are no similar issues
 - [ ] I have confirmed I am using the latest version
-- [ ] I have thoroughly read the project README, especially the FAQ section
+- [ ] I have thoroughly read the project README
-- [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback
+- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed**

 ## Issue Description

@@ -3,10 +3,11 @@ package controller

 import (
     "fmt"
     "net/http"
-    "strings"
+    "sort"

     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common/ctxkey"
+    "github.com/songquanpeng/one-api/middleware"
     "github.com/songquanpeng/one-api/model"
     relay "github.com/songquanpeng/one-api/relay"
     "github.com/songquanpeng/one-api/relay/adaptor/openai"

@@ -34,16 +35,20 @@ type OpenAIModelPermission struct {
 }

 type OpenAIModels struct {
-    Id      string `json:"id"`
-    Object  string `json:"object"`
-    Created int    `json:"created"`
+    // Id model's name
+    //
+    // BUG: Different channels may have the same model name
+    Id      string `json:"id"`
+    Object  string `json:"object"`
+    Created int    `json:"created"`
+    // OwnedBy is the channel's adaptor name
     OwnedBy    string                  `json:"owned_by"`
     Permission []OpenAIModelPermission `json:"permission"`
     Root       string                  `json:"root"`
     Parent     *string                 `json:"parent"`
 }

-var models []OpenAIModels
+var allModels []OpenAIModels
 var modelsMap map[string]OpenAIModels
 var channelId2Models map[int][]string

@@ -76,7 +81,7 @@ func init() {
         channelName := adaptor.GetChannelName()
         modelNames := adaptor.GetModelList()
         for _, modelName := range modelNames {
-            models = append(models, OpenAIModels{
+            allModels = append(allModels, OpenAIModels{
                 Id:      modelName,
                 Object:  "model",
                 Created: 1626777600,

@@ -93,7 +98,7 @@ func init() {
         }
         channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
         for _, modelName := range channelModelList {
-            models = append(models, OpenAIModels{
+            allModels = append(allModels, OpenAIModels{
                 Id:      modelName,
                 Object:  "model",
                 Created: 1626777600,

@@ -105,7 +110,7 @@ func init() {
         }
     }
     modelsMap = make(map[string]OpenAIModels)
-    for _, model := range models {
+    for _, model := range allModels {
         modelsMap[model.Id] = model
     }
     channelId2Models = make(map[int][]string)

@@ -134,49 +139,56 @@ func DashboardListModels(c *gin.Context) {
 func ListAllModels(c *gin.Context) {
     c.JSON(200, gin.H{
         "object": "list",
-        "data":   models,
+        "data":   allModels,
     })
 }

 func ListModels(c *gin.Context) {
-    ctx := c.Request.Context()
-    var availableModels []string
-    if c.GetString(ctxkey.AvailableModels) != "" {
-        availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
-    } else {
-        userId := c.GetInt(ctxkey.Id)
-        userGroup, _ := model.CacheGetUserGroup(userId)
-        availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
+    userId := c.GetInt(ctxkey.Id)
+    userGroup, err := model.CacheGetUserGroup(userId)
+    if err != nil {
+        middleware.AbortWithError(c, http.StatusBadRequest, err)
+        return
     }
-    modelSet := make(map[string]bool)
-    for _, availableModel := range availableModels {
-        modelSet[availableModel] = true

+    // Get available models with their channel names
+    availableAbilities, err := model.GetGroupModelsV2(c.Request.Context(), userGroup)
+    if err != nil {
+        middleware.AbortWithError(c, http.StatusBadRequest, err)
+        return
+    }
+
+    // Create a map for quick lookup of enabled model+channel combinations
+    // Only store the exact model:channel combinations from abilities
+    abilityMap := make(map[string]bool)
+    for _, ability := range availableAbilities {
+        // Store as "modelName:channelName" for exact matching
+        adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType))
+        key := ability.Model + ":" + adaptor.GetChannelName()
+        abilityMap[key] = true
+    }
+
+    // Filter models that match user's abilities with EXACT model+channel matches
     availableOpenAIModels := make([]OpenAIModels, 0)
-    for _, model := range models {
-        if _, ok := modelSet[model.Id]; ok {
-            modelSet[model.Id] = false
+
+    // Only include models that have a matching model+channel combination
+    for _, model := range allModels {
+        key := model.Id + ":" + model.OwnedBy
+        if abilityMap[key] {
             availableOpenAIModels = append(availableOpenAIModels, model)
         }
     }
-    for modelName, ok := range modelSet {
-        if ok {
-            availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
-                Id:      modelName,
-                Object:  "model",
-                Created: 1626777600,
-                OwnedBy: "custom",
-                Root:    modelName,
-                Parent:  nil,
-            })
-        }
-    }

+    // Sort models alphabetically for consistent presentation
+    sort.Slice(availableOpenAIModels, func(i, j int) bool {
+        return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id
+    })
+
     c.JSON(200, gin.H{
         "object": "list",
         "data":   availableOpenAIModels,
     })
 }

 func RetrieveModel(c *gin.Context) {
     modelId := c.Param("model")
     if model, ok := modelsMap[modelId]; ok {

@@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) {
         key = parts[0]
         token, err := model.ValidateUserToken(key)
         if err != nil {
-            abortWithError(c, http.StatusUnauthorized, err)
+            AbortWithError(c, http.StatusUnauthorized, err)
             return
         }
         if token.Subnet != nil && *token.Subnet != "" {
             if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
-                abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
+                AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
                 return
             }
         }
         userEnabled, err := model.CacheIsUserEnabled(token.UserId)
         if err != nil {
-            abortWithError(c, http.StatusInternalServerError, err)
+            AbortWithError(c, http.StatusInternalServerError, err)
             return
         }
         if !userEnabled || blacklist.IsUserBanned(token.UserId) {
-            abortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
+            AbortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
             return
         }
         requestModel, err := getRequestModel(c)
         if err != nil && shouldCheckModel(c) {
-            abortWithError(c, http.StatusBadRequest, err)
+            AbortWithError(c, http.StatusBadRequest, err)
             return
         }
         c.Set(ctxkey.RequestModel, requestModel)
         if token.Models != nil && *token.Models != "" {
             c.Set(ctxkey.AvailableModels, *token.Models)
             if requestModel != "" && !isModelInList(requestModel, *token.Models) {
-                abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
+                AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
                 return
             }
         }

@@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) {
             if model.IsAdmin(token.UserId) {
                 c.Set(ctxkey.SpecificChannelId, parts[1])
             } else {
-                abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
+                AbortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
                 return
             }
         }

@@ -32,16 +32,16 @@ func Distribute() func(c *gin.Context) {
         if ok {
             id, err := strconv.Atoi(channelId.(string))
             if err != nil {
-                abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
+                AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
                 return
             }
             channel, err = model.GetChannelById(id, true)
             if err != nil {
-                abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
+                AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
                 return
             }
             if channel.Status != model.ChannelStatusEnabled {
-                abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
+                AbortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
                 return
             }
         } else {

@@ -54,7 +54,7 @@ func Distribute() func(c *gin.Context) {
                 logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
                 message = "Database consistency has been broken, please contact the administrator"
             }
-            abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
+            AbortWithError(c, http.StatusServiceUnavailable, errors.New(message))
             return
         }
     }

@@ -23,7 +23,8 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
     logger.Error(c.Request.Context(), message)
 }

-func abortWithError(c *gin.Context, statusCode int, err error) {
+// AbortWithError aborts the request with an error message
+func AbortWithError(c *gin.Context, statusCode int, err error) {
     logger := gmw.GetLogger(c)
     logger.Error("server abort", zap.Error(err))
     c.JSON(statusCode, gin.H{

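Exporting the helper is what lets the model controller above reuse it. A small hypothetical handler illustrates the intended call pattern; the `getThing` and `loadThing` names are placeholders, not part of the repository:

```go
package controller

import (
	"errors"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/middleware"
)

// loadThing is a placeholder for any fallible lookup.
func loadThing(id string) (string, error) {
	if id == "" {
		return "", errors.New("missing id")
	}
	return "thing-" + id, nil
}

// getThing shows the call pattern enabled by the exported helper: any handler
// outside the middleware package can now return the same JSON error shape.
func getThing(c *gin.Context) {
	thing, err := loadThing(c.Param("id"))
	if err != nil {
		middleware.AbortWithError(c, http.StatusBadRequest, err)
		return
	}
	c.JSON(http.StatusOK, thing)
}
```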
@@ -4,11 +4,13 @@ import (
     "context"
     "sort"
     "strings"
+    "time"

+    "gorm.io/gorm"
+
+    gutils "github.com/Laisky/go-utils/v5"
     "github.com/pkg/errors"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/common/utils"
-    "gorm.io/gorm"
 )

 type Ability struct {

@@ -110,3 +112,44 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) {
     sort.Strings(models)
     return models, err
 }
+
+var getGroupModelsV2Cache = gutils.NewExpCache[[]EnabledAbility](context.Background(), time.Second*10)
+
+type EnabledAbility struct {
+    Model       string `json:"model" gorm:"model"`
+    ChannelType int    `json:"channel_type" gorm:"channel_type"`
+}
+
+// GetGroupModelsV2 returns all enabled models for this group with their channel names.
+func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) {
+    // get from cache first
+    if models, ok := getGroupModelsV2Cache.Load(group); ok {
+        return models, nil
+    }
+
+    // prepare query based on database type
+    groupCol := "`group`"
+    trueVal := "1"
+    if common.UsingPostgreSQL {
+        groupCol = `"group"`
+        trueVal = "true"
+    }
+
+    // query with JOIN to get model and channel name in a single query
+    var models []EnabledAbility
+    query := DB.Model(&Ability{}).
+        Select("abilities.model AS model, channels.type AS channel_type").
+        Joins("JOIN channels ON abilities.channel_id = channels.id").
+        Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group).
+        Order("abilities.priority DESC")
+
+    err := query.Find(&models).Error
+    if err != nil {
+        return nil, errors.Wrap(err, "get group models")
+    }
+
+    // store in cache
+    getGroupModelsV2Cache.Store(group, models)
+
+    return models, nil
+}

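For readers who prefer raw SQL, the chained GORM calls in `GetGroupModelsV2` correspond roughly to the statements below. These are approximations written for illustration only; the exact table name, identifier quoting, and placeholder syntax come from the GORM driver, not from this repository:

```go
package model

// Rough SQL equivalents of the GetGroupModelsV2 query builder, for illustration only.
const (
	// MySQL/SQLite flavour: `group` quoted with backticks, enabled compared to 1.
	groupModelsSQLMySQL = "SELECT abilities.model AS model, channels.type AS channel_type " +
		"FROM abilities JOIN channels ON abilities.channel_id = channels.id " +
		"WHERE abilities.`group` = ? AND abilities.enabled = 1 " +
		"ORDER BY abilities.priority DESC"

	// PostgreSQL flavour: "group" quoted with double quotes, enabled compared to true.
	groupModelsSQLPostgres = `SELECT abilities.model AS model, channels.type AS channel_type ` +
		`FROM abilities JOIN channels ON abilities.channel_id = channels.id ` +
		`WHERE abilities."group" = $1 AND abilities.enabled = true ` +
		`ORDER BY abilities.priority DESC`
)
```

The 10-second `getGroupModelsV2Cache` in front of this query keeps `/v1/models` from hitting the database on every request while still reflecting channel enable/disable changes quickly.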
@@ -4,17 +4,18 @@ import (
     "context"
     "encoding/json"
     "fmt"
-    "github.com/pkg/errors"
-    "github.com/songquanpeng/one-api/common"
-    "github.com/songquanpeng/one-api/common/config"
-    "github.com/songquanpeng/one-api/common/logger"
-    "github.com/songquanpeng/one-api/common/random"
     "math/rand"
     "sort"
     "strconv"
     "strings"
     "sync"
     "time"
+
+    "github.com/pkg/errors"
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/common/config"
+    "github.com/songquanpeng/one-api/common/logger"
+    "github.com/songquanpeng/one-api/common/random"
 )

 var (

@@ -1,6 +1,7 @@
 package model

 import (
+    "context"
     "database/sql"
     "fmt"
     "os"

@@ -118,6 +119,11 @@ func InitDB() {
         return
     }

+    if config.DebugSQLEnabled {
+        logger.Debug(context.TODO(), "debug sql enabled")
+        DB = DB.Debug()
+    }
+
     sqlDB := setDBConns(DB)

     if !config.IsMasterNode {

@@ -203,10 +209,6 @@ func migrateLOGDB() error {
 }

 func setDBConns(db *gorm.DB) *sql.DB {
-    if config.DebugSQLEnabled {
-        db = db.Debug()
-    }
-
     sqlDB, err := db.DB()
     if err != nil {
         logger.FatalLog("failed to connect database: " + err.Error())