mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	fix: models api return models in deactivate channels
- Enhance logging functionality by adding context support and improving debugging options. - Standardize function naming conventions across middleware to ensure consistency. - Optimize data retrieval and handling in the model controller, including caching and error management. - Simplify the bug report template to streamline the issue reporting process.
This commit is contained in:
		
							
								
								
									
										3
									
								
								.github/ISSUE_TEMPLATE/bug_report.md
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										3
									
								
								.github/ISSUE_TEMPLATE/bug_report.md
									
									
									
									
										vendored
									
									
								
							@@ -12,9 +12,8 @@ assignees: ""
 | 
			
		||||
 | 
			
		||||
- [ ] I have confirmed there are no similar issues
 | 
			
		||||
- [ ] I have confirmed I am using the latest version
 | 
			
		||||
- [ ] I have thoroughly read the project README, especially the FAQ section
 | 
			
		||||
- [ ] I have thoroughly read the project README
 | 
			
		||||
- [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback
 | 
			
		||||
- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed**
 | 
			
		||||
 | 
			
		||||
## Issue Description
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,10 +3,11 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sort"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/middleware"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	relay "github.com/songquanpeng/one-api/relay"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
@@ -34,16 +35,20 @@ type OpenAIModelPermission struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAIModels struct {
 | 
			
		||||
	Id         string                  `json:"id"`
 | 
			
		||||
	Object     string                  `json:"object"`
 | 
			
		||||
	Created    int                     `json:"created"`
 | 
			
		||||
	// Id model's name
 | 
			
		||||
	//
 | 
			
		||||
	// BUG: Different channels may have the same model name
 | 
			
		||||
	Id      string `json:"id"`
 | 
			
		||||
	Object  string `json:"object"`
 | 
			
		||||
	Created int    `json:"created"`
 | 
			
		||||
	// OwnedBy is the channel's adaptor name
 | 
			
		||||
	OwnedBy    string                  `json:"owned_by"`
 | 
			
		||||
	Permission []OpenAIModelPermission `json:"permission"`
 | 
			
		||||
	Root       string                  `json:"root"`
 | 
			
		||||
	Parent     *string                 `json:"parent"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var models []OpenAIModels
 | 
			
		||||
var allModels []OpenAIModels
 | 
			
		||||
var modelsMap map[string]OpenAIModels
 | 
			
		||||
var channelId2Models map[int][]string
 | 
			
		||||
 | 
			
		||||
@@ -76,7 +81,7 @@ func init() {
 | 
			
		||||
		channelName := adaptor.GetChannelName()
 | 
			
		||||
		modelNames := adaptor.GetModelList()
 | 
			
		||||
		for _, modelName := range modelNames {
 | 
			
		||||
			models = append(models, OpenAIModels{
 | 
			
		||||
			allModels = append(allModels, OpenAIModels{
 | 
			
		||||
				Id:         modelName,
 | 
			
		||||
				Object:     "model",
 | 
			
		||||
				Created:    1626777600,
 | 
			
		||||
@@ -93,7 +98,7 @@ func init() {
 | 
			
		||||
		}
 | 
			
		||||
		channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
 | 
			
		||||
		for _, modelName := range channelModelList {
 | 
			
		||||
			models = append(models, OpenAIModels{
 | 
			
		||||
			allModels = append(allModels, OpenAIModels{
 | 
			
		||||
				Id:         modelName,
 | 
			
		||||
				Object:     "model",
 | 
			
		||||
				Created:    1626777600,
 | 
			
		||||
@@ -105,7 +110,7 @@ func init() {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	modelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range models {
 | 
			
		||||
	for _, model := range allModels {
 | 
			
		||||
		modelsMap[model.Id] = model
 | 
			
		||||
	}
 | 
			
		||||
	channelId2Models = make(map[int][]string)
 | 
			
		||||
@@ -134,49 +139,56 @@ func DashboardListModels(c *gin.Context) {
 | 
			
		||||
func ListAllModels(c *gin.Context) {
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"object": "list",
 | 
			
		||||
		"data":   models,
 | 
			
		||||
		"data":   allModels,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ListModels(c *gin.Context) {
 | 
			
		||||
	ctx := c.Request.Context()
 | 
			
		||||
	var availableModels []string
 | 
			
		||||
	if c.GetString(ctxkey.AvailableModels) != "" {
 | 
			
		||||
		availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",")
 | 
			
		||||
	} else {
 | 
			
		||||
		userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
		userGroup, _ := model.CacheGetUserGroup(userId)
 | 
			
		||||
		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
 | 
			
		||||
	userId := c.GetInt(ctxkey.Id)
 | 
			
		||||
	userGroup, err := model.CacheGetUserGroup(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		middleware.AbortWithError(c, http.StatusBadRequest, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	modelSet := make(map[string]bool)
 | 
			
		||||
	for _, availableModel := range availableModels {
 | 
			
		||||
		modelSet[availableModel] = true
 | 
			
		||||
 | 
			
		||||
	// Get available models with their channel names
 | 
			
		||||
	availableAbilities, err := model.GetGroupModelsV2(c.Request.Context(), userGroup)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		middleware.AbortWithError(c, http.StatusBadRequest, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Create a map for quick lookup of enabled model+channel combinations
 | 
			
		||||
	// Only store the exact model:channel combinations from abilities
 | 
			
		||||
	abilityMap := make(map[string]bool)
 | 
			
		||||
	for _, ability := range availableAbilities {
 | 
			
		||||
		// Store as "modelName:channelName" for exact matching
 | 
			
		||||
		adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType))
 | 
			
		||||
		key := ability.Model + ":" + adaptor.GetChannelName()
 | 
			
		||||
		abilityMap[key] = true
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Filter models that match user's abilities with EXACT model+channel matches
 | 
			
		||||
	availableOpenAIModels := make([]OpenAIModels, 0)
 | 
			
		||||
	for _, model := range models {
 | 
			
		||||
		if _, ok := modelSet[model.Id]; ok {
 | 
			
		||||
			modelSet[model.Id] = false
 | 
			
		||||
 | 
			
		||||
	// Only include models that have a matching model+channel combination
 | 
			
		||||
	for _, model := range allModels {
 | 
			
		||||
		key := model.Id + ":" + model.OwnedBy
 | 
			
		||||
		if abilityMap[key] {
 | 
			
		||||
			availableOpenAIModels = append(availableOpenAIModels, model)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	for modelName, ok := range modelSet {
 | 
			
		||||
		if ok {
 | 
			
		||||
			availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
 | 
			
		||||
				Id:      modelName,
 | 
			
		||||
				Object:  "model",
 | 
			
		||||
				Created: 1626777600,
 | 
			
		||||
				OwnedBy: "custom",
 | 
			
		||||
				Root:    modelName,
 | 
			
		||||
				Parent:  nil,
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Sort models alphabetically for consistent presentation
 | 
			
		||||
	sort.Slice(availableOpenAIModels, func(i, j int) bool {
 | 
			
		||||
		return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	c.JSON(200, gin.H{
 | 
			
		||||
		"object": "list",
 | 
			
		||||
		"data":   availableOpenAIModels,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RetrieveModel(c *gin.Context) {
 | 
			
		||||
	modelId := c.Param("model")
 | 
			
		||||
	if model, ok := modelsMap[modelId]; ok {
 | 
			
		||||
 
 | 
			
		||||
@@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
		key = parts[0]
 | 
			
		||||
		token, err := model.ValidateUserToken(key)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			abortWithError(c, http.StatusUnauthorized, err)
 | 
			
		||||
			AbortWithError(c, http.StatusUnauthorized, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if token.Subnet != nil && *token.Subnet != "" {
 | 
			
		||||
			if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
 | 
			
		||||
				abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
 | 
			
		||||
				AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			abortWithError(c, http.StatusInternalServerError, err)
 | 
			
		||||
			AbortWithError(c, http.StatusInternalServerError, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled || blacklist.IsUserBanned(token.UserId) {
 | 
			
		||||
			abortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
 | 
			
		||||
			AbortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		requestModel, err := getRequestModel(c)
 | 
			
		||||
		if err != nil && shouldCheckModel(c) {
 | 
			
		||||
			abortWithError(c, http.StatusBadRequest, err)
 | 
			
		||||
			AbortWithError(c, http.StatusBadRequest, err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set(ctxkey.RequestModel, requestModel)
 | 
			
		||||
		if token.Models != nil && *token.Models != "" {
 | 
			
		||||
			c.Set(ctxkey.AvailableModels, *token.Models)
 | 
			
		||||
			if requestModel != "" && !isModelInList(requestModel, *token.Models) {
 | 
			
		||||
				abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
 | 
			
		||||
				AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			if model.IsAdmin(token.UserId) {
 | 
			
		||||
				c.Set(ctxkey.SpecificChannelId, parts[1])
 | 
			
		||||
			} else {
 | 
			
		||||
				abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
 | 
			
		||||
				AbortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,16 +32,16 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		if ok {
 | 
			
		||||
			id, err := strconv.Atoi(channelId.(string))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 | 
			
		||||
				AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.GetChannelById(id, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 | 
			
		||||
				AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != model.ChannelStatusEnabled {
 | 
			
		||||
				abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
 | 
			
		||||
				AbortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		} else {
 | 
			
		||||
@@ -54,7 +54,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
 | 
			
		||||
					message = "Database consistency has been broken, please contact the administrator"
 | 
			
		||||
				}
 | 
			
		||||
				abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
 | 
			
		||||
				AbortWithError(c, http.StatusServiceUnavailable, errors.New(message))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,8 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 | 
			
		||||
	logger.Error(c.Request.Context(), message)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func abortWithError(c *gin.Context, statusCode int, err error) {
 | 
			
		||||
// AbortWithError aborts the request with an error message
 | 
			
		||||
func AbortWithError(c *gin.Context, statusCode int, err error) {
 | 
			
		||||
	logger := gmw.GetLogger(c)
 | 
			
		||||
	logger.Error("server abort", zap.Error(err))
 | 
			
		||||
	c.JSON(statusCode, gin.H{
 | 
			
		||||
 
 | 
			
		||||
@@ -4,11 +4,13 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
 | 
			
		||||
	gutils "github.com/Laisky/go-utils/v5"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/utils"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Ability struct {
 | 
			
		||||
@@ -110,3 +112,44 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) {
 | 
			
		||||
	sort.Strings(models)
 | 
			
		||||
	return models, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var getGroupModelsV2Cache = gutils.NewExpCache[[]EnabledAbility](context.Background(), time.Second*10)
 | 
			
		||||
 | 
			
		||||
type EnabledAbility struct {
 | 
			
		||||
	Model       string `json:"model" gorm:"model"`
 | 
			
		||||
	ChannelType int    `json:"channel_type" gorm:"channel_type"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetGroupModelsV2 returns all enabled models for this group with their channel names.
 | 
			
		||||
func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) {
 | 
			
		||||
	// get from cache first
 | 
			
		||||
	if models, ok := getGroupModelsV2Cache.Load(group); ok {
 | 
			
		||||
		return models, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// prepare query based on database type
 | 
			
		||||
	groupCol := "`group`"
 | 
			
		||||
	trueVal := "1"
 | 
			
		||||
	if common.UsingPostgreSQL {
 | 
			
		||||
		groupCol = `"group"`
 | 
			
		||||
		trueVal = "true"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// query with JOIN to get model and channel name in a single query
 | 
			
		||||
	var models []EnabledAbility
 | 
			
		||||
	query := DB.Model(&Ability{}).
 | 
			
		||||
		Select("abilities.model AS model, channels.type AS channel_type").
 | 
			
		||||
		Joins("JOIN channels ON abilities.channel_id = channels.id").
 | 
			
		||||
		Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group).
 | 
			
		||||
		Order("abilities.priority DESC")
 | 
			
		||||
 | 
			
		||||
	err := query.Find(&models).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Wrap(err, "get group models")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// store in cache
 | 
			
		||||
	getGroupModelsV2Cache.Store(group, models)
 | 
			
		||||
 | 
			
		||||
	return models, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,17 +4,18 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/pkg/errors"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -118,6 +119,11 @@ func InitDB() {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.DebugSQLEnabled {
 | 
			
		||||
		logger.Debug(context.TODO(), "debug sql enabled")
 | 
			
		||||
		DB = DB.Debug()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlDB := setDBConns(DB)
 | 
			
		||||
 | 
			
		||||
	if !config.IsMasterNode {
 | 
			
		||||
@@ -203,10 +209,6 @@ func migrateLOGDB() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setDBConns(db *gorm.DB) *sql.DB {
 | 
			
		||||
	if config.DebugSQLEnabled {
 | 
			
		||||
		db = db.Debug()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlDB, err := db.DB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.FatalLog("failed to connect database: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user