mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	chore: reorganize common package
This commit is contained in:
		
							
								
								
									
										9
									
								
								common/config/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								common/config/key.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package config
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	KeyPrefix = "cfg_"
 | 
			
		||||
 | 
			
		||||
	KeyAPIVersion = KeyPrefix + "api_version"
 | 
			
		||||
	KeyLibraryID  = KeyPrefix + "library_id"
 | 
			
		||||
	KeyPlugin     = KeyPrefix + "plugin"
 | 
			
		||||
)
 | 
			
		||||
@@ -4,80 +4,3 @@ import "time"
 | 
			
		||||
 | 
			
		||||
var StartTime = time.Now().Unix() // unit: second
 | 
			
		||||
var Version = "v0.0.0"            // this hard coding will be replaced automatically when building, no need to manually change
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
	RoleAdminUser  = 10
 | 
			
		||||
	RoleRootUser   = 100
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	UserStatusDisabled = 2 // also don't use 0
 | 
			
		||||
	UserStatusDeleted  = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	TokenStatusDisabled  = 2 // also don't use 0
 | 
			
		||||
	TokenStatusExpired   = 3
 | 
			
		||||
	TokenStatusExhausted = 4
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	RedemptionCodeStatusDisabled = 2 // also don't use 0
 | 
			
		||||
	RedemptionCodeStatusUsed     = 3 // also don't use 0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ChannelStatusUnknown          = 0
 | 
			
		||||
	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	ChannelStatusManuallyDisabled = 2 // also don't use 0
 | 
			
		||||
	ChannelStatusAutoDisabled     = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                              // 0
 | 
			
		||||
	"https://api.openai.com",        // 1
 | 
			
		||||
	"https://oa.api2d.net",          // 2
 | 
			
		||||
	"",                              // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz", // 4
 | 
			
		||||
	"https://api.openai-sb.com",     // 5
 | 
			
		||||
	"https://api.openaimax.com",     // 6
 | 
			
		||||
	"https://api.ohmygpt.com",       // 7
 | 
			
		||||
	"",                              // 8
 | 
			
		||||
	"https://api.caipacity.com",     // 9
 | 
			
		||||
	"https://api.aiproxy.io",        // 10
 | 
			
		||||
	"https://generativelanguage.googleapis.com", // 11
 | 
			
		||||
	"https://api.api2gpt.com",                   // 12
 | 
			
		||||
	"https://api.aigc2d.com",                    // 13
 | 
			
		||||
	"https://api.anthropic.com",                 // 14
 | 
			
		||||
	"https://aip.baidubce.com",                  // 15
 | 
			
		||||
	"https://open.bigmodel.cn",                  // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com",            // 17
 | 
			
		||||
	"",                                          // 18
 | 
			
		||||
	"https://ai.360.cn",                         // 19
 | 
			
		||||
	"https://openrouter.ai/api",                 // 20
 | 
			
		||||
	"https://api.aiproxy.io",                    // 21
 | 
			
		||||
	"https://fastgpt.run/api/openapi",           // 22
 | 
			
		||||
	"https://hunyuan.cloud.tencent.com",         // 23
 | 
			
		||||
	"https://generativelanguage.googleapis.com", // 24
 | 
			
		||||
	"https://api.moonshot.cn",                   // 25
 | 
			
		||||
	"https://api.baichuan-ai.com",               // 26
 | 
			
		||||
	"https://api.minimax.chat",                  // 27
 | 
			
		||||
	"https://api.mistral.ai",                    // 28
 | 
			
		||||
	"https://api.groq.com/openai",               // 29
 | 
			
		||||
	"http://localhost:11434",                    // 30
 | 
			
		||||
	"https://api.lingyiwanwu.com",               // 31
 | 
			
		||||
	"https://api.stepfun.com",                   // 32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ConfigKeyPrefix = "cfg_"
 | 
			
		||||
 | 
			
		||||
	ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
 | 
			
		||||
	ConfigKeyLibraryID  = ConfigKeyPrefix + "library_id"
 | 
			
		||||
	ConfigKeyPlugin     = ConfigKeyPrefix + "plugin"
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
@@ -134,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
				user.DisplayName = "GitHub User"
 | 
			
		||||
			}
 | 
			
		||||
			user.Email = githubUser.Email
 | 
			
		||||
			user.Role = common.RoleCommonUser
 | 
			
		||||
			user.Status = common.UserStatusEnabled
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -153,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status != common.UserStatusEnabled {
 | 
			
		||||
	if user.Status != model.UserStatusEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
			"success": false,
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
@@ -123,8 +122,8 @@ func LarkOAuth(c *gin.Context) {
 | 
			
		||||
			} else {
 | 
			
		||||
				user.DisplayName = "Lark User"
 | 
			
		||||
			}
 | 
			
		||||
			user.Role = common.RoleCommonUser
 | 
			
		||||
			user.Status = common.UserStatusEnabled
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -142,7 +141,7 @@ func LarkOAuth(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status != common.UserStatusEnabled {
 | 
			
		||||
	if user.Status != model.UserStatusEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
			"success": false,
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/controller"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -84,8 +83,8 @@ func WeChatAuth(c *gin.Context) {
 | 
			
		||||
		if config.RegisterEnabled {
 | 
			
		||||
			user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
 | 
			
		||||
			user.DisplayName = "WeChat User"
 | 
			
		||||
			user.Role = common.RoleCommonUser
 | 
			
		||||
			user.Status = common.UserStatusEnabled
 | 
			
		||||
			user.Role = model.RoleCommonUser
 | 
			
		||||
			user.Status = model.UserStatusEnabled
 | 
			
		||||
 | 
			
		||||
			if err := user.Insert(0); err != nil {
 | 
			
		||||
				c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
@@ -103,7 +102,7 @@ func WeChatAuth(c *gin.Context) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status != common.UserStatusEnabled {
 | 
			
		||||
	if user.Status != model.UserStatusEnabled {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
			"success": false,
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -205,7 +204,7 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	baseURL := channeltype.ChannelBaseURLs[channel.Type]
 | 
			
		||||
	if channel.GetBaseURL() == "" {
 | 
			
		||||
		channel.BaseURL = &baseURL
 | 
			
		||||
	}
 | 
			
		||||
@@ -302,7 +301,7 @@ func updateAllChannelsBalance() error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, channel := range channels {
 | 
			
		||||
		if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
		if channel.Status != model.ChannelStatusEnabled {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		// TODO: support Azure
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/message"
 | 
			
		||||
@@ -173,7 +172,7 @@ func testChannels(notify bool, scope string) error {
 | 
			
		||||
	}
 | 
			
		||||
	go func() {
 | 
			
		||||
		for _, channel := range channels {
 | 
			
		||||
			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 | 
			
		||||
			isChannelEnabled := channel.Status == model.ChannelStatusEnabled
 | 
			
		||||
			tik := time.Now()
 | 
			
		||||
			err, openaiErr := testChannel(channel)
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,6 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
@@ -213,15 +212,15 @@ func UpdateToken(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if token.Status == common.TokenStatusEnabled {
 | 
			
		||||
		if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
 | 
			
		||||
	if token.Status == model.TokenStatusEnabled {
 | 
			
		||||
		if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
 | 
			
		||||
		if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
 | 
			
		||||
 
 | 
			
		||||
@@ -239,7 +239,7 @@ func GetUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	if myRole <= user.Role && myRole != common.RoleRootUser {
 | 
			
		||||
	if myRole <= user.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权获取同级或更高等级用户的信息",
 | 
			
		||||
@@ -388,14 +388,14 @@ func UpdateUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	if myRole <= originUser.Role && myRole != common.RoleRootUser {
 | 
			
		||||
	if myRole <= originUser.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权更新同权限等级或更高权限等级的用户信息",
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
 | 
			
		||||
	if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
 | 
			
		||||
@@ -509,7 +509,7 @@ func DeleteSelf(c *gin.Context) {
 | 
			
		||||
	id := c.GetInt("id")
 | 
			
		||||
	user, _ := model.GetUserById(id, false)
 | 
			
		||||
 | 
			
		||||
	if user.Role == common.RoleRootUser {
 | 
			
		||||
	if user.Role == model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "不能删除超级管理员账户",
 | 
			
		||||
@@ -611,7 +611,7 @@ func ManageUser(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	myRole := c.GetInt("role")
 | 
			
		||||
	if myRole <= user.Role && myRole != common.RoleRootUser {
 | 
			
		||||
	if myRole <= user.Role && myRole != model.RoleRootUser {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "无权更新同权限等级或更高权限等级的用户信息",
 | 
			
		||||
@@ -620,8 +620,8 @@ func ManageUser(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	switch req.Action {
 | 
			
		||||
	case "disable":
 | 
			
		||||
		user.Status = common.UserStatusDisabled
 | 
			
		||||
		if user.Role == common.RoleRootUser {
 | 
			
		||||
		user.Status = model.UserStatusDisabled
 | 
			
		||||
		if user.Role == model.RoleRootUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "无法禁用超级管理员用户",
 | 
			
		||||
@@ -629,9 +629,9 @@ func ManageUser(c *gin.Context) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case "enable":
 | 
			
		||||
		user.Status = common.UserStatusEnabled
 | 
			
		||||
		user.Status = model.UserStatusEnabled
 | 
			
		||||
	case "delete":
 | 
			
		||||
		if user.Role == common.RoleRootUser {
 | 
			
		||||
		if user.Role == model.RoleRootUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "无法删除超级管理员用户",
 | 
			
		||||
@@ -646,37 +646,37 @@ func ManageUser(c *gin.Context) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	case "promote":
 | 
			
		||||
		if myRole != common.RoleRootUser {
 | 
			
		||||
		if myRole != model.RoleRootUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "普通管理员用户无法提升其他用户为管理员",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if user.Role >= common.RoleAdminUser {
 | 
			
		||||
		if user.Role >= model.RoleAdminUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "该用户已经是管理员",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		user.Role = common.RoleAdminUser
 | 
			
		||||
		user.Role = model.RoleAdminUser
 | 
			
		||||
	case "demote":
 | 
			
		||||
		if user.Role == common.RoleRootUser {
 | 
			
		||||
		if user.Role == model.RoleRootUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "无法降级超级管理员用户",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if user.Role == common.RoleCommonUser {
 | 
			
		||||
		if user.Role == model.RoleCommonUser {
 | 
			
		||||
			c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
				"success": false,
 | 
			
		||||
				"message": "该用户已经是普通用户",
 | 
			
		||||
			})
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		user.Role = common.RoleCommonUser
 | 
			
		||||
		user.Role = model.RoleCommonUser
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := user.Update(false); err != nil {
 | 
			
		||||
@@ -730,7 +730,7 @@ func EmailBind(c *gin.Context) {
 | 
			
		||||
		})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if user.Role == common.RoleRootUser {
 | 
			
		||||
	if user.Role == model.RoleRootUser {
 | 
			
		||||
		config.RootUserEmail = email
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/blacklist"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/network"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
@@ -45,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
 | 
			
		||||
	if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
 | 
			
		||||
		c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
			"success": false,
 | 
			
		||||
			"message": "用户已被封禁",
 | 
			
		||||
@@ -72,19 +71,19 @@ func authHelper(c *gin.Context, minRole int) {
 | 
			
		||||
 | 
			
		||||
func UserAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		authHelper(c, common.RoleCommonUser)
 | 
			
		||||
		authHelper(c, model.RoleCommonUser)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func AdminAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		authHelper(c, common.RoleAdminUser)
 | 
			
		||||
		authHelper(c, model.RoleAdminUser)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RootAuth() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		authHelper(c, common.RoleRootUser)
 | 
			
		||||
		authHelper(c, model.RoleRootUser)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ package middleware
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/model"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
@@ -34,7 +34,7 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
				abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if channel.Status != common.ChannelStatusEnabled {
 | 
			
		||||
			if channel.Status != model.ChannelStatusEnabled {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
@@ -68,18 +68,18 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
 | 
			
		||||
	// this is for backward compatibility
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case channeltype.Azure:
 | 
			
		||||
		c.Set(common.ConfigKeyAPIVersion, channel.Other)
 | 
			
		||||
		c.Set(config.KeyAPIVersion, channel.Other)
 | 
			
		||||
	case channeltype.Xunfei:
 | 
			
		||||
		c.Set(common.ConfigKeyAPIVersion, channel.Other)
 | 
			
		||||
		c.Set(config.KeyAPIVersion, channel.Other)
 | 
			
		||||
	case channeltype.Gemini:
 | 
			
		||||
		c.Set(common.ConfigKeyAPIVersion, channel.Other)
 | 
			
		||||
		c.Set(config.KeyAPIVersion, channel.Other)
 | 
			
		||||
	case channeltype.AIProxyLibrary:
 | 
			
		||||
		c.Set(common.ConfigKeyLibraryID, channel.Other)
 | 
			
		||||
		c.Set(config.KeyLibraryID, channel.Other)
 | 
			
		||||
	case channeltype.Ali:
 | 
			
		||||
		c.Set(common.ConfigKeyPlugin, channel.Other)
 | 
			
		||||
		c.Set(config.KeyPlugin, channel.Other)
 | 
			
		||||
	}
 | 
			
		||||
	cfg, _ := channel.LoadConfig()
 | 
			
		||||
	for k, v := range cfg {
 | 
			
		||||
		c.Set(common.ConfigKeyPrefix+k, v)
 | 
			
		||||
		c.Set(config.KeyPrefix+k, v)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -57,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
 | 
			
		||||
				Group:     group,
 | 
			
		||||
				Model:     model,
 | 
			
		||||
				ChannelId: channel.Id,
 | 
			
		||||
				Enabled:   channel.Status == common.ChannelStatusEnabled,
 | 
			
		||||
				Enabled:   channel.Status == ChannelStatusEnabled,
 | 
			
		||||
				Priority:  channel.Priority,
 | 
			
		||||
			}
 | 
			
		||||
			abilities = append(abilities, ability)
 | 
			
		||||
 
 | 
			
		||||
@@ -173,7 +173,7 @@ var channelSyncLock sync.RWMutex
 | 
			
		||||
func InitChannelCache() {
 | 
			
		||||
	newChannelId2channel := make(map[int]*Channel)
 | 
			
		||||
	var channels []*Channel
 | 
			
		||||
	DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
 | 
			
		||||
	DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
 | 
			
		||||
	for _, channel := range channels {
 | 
			
		||||
		newChannelId2channel[channel.Id] = channel
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,13 +3,19 @@ package model
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ChannelStatusUnknown          = 0
 | 
			
		||||
	ChannelStatusEnabled          = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	ChannelStatusManuallyDisabled = 2 // also don't use 0
 | 
			
		||||
	ChannelStatusAutoDisabled     = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Channel struct {
 | 
			
		||||
	Id                 int     `json:"id"`
 | 
			
		||||
	Type               int     `json:"type" gorm:"default:0"`
 | 
			
		||||
@@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
 | 
			
		||||
	case "all":
 | 
			
		||||
		err = DB.Order("id desc").Find(&channels).Error
 | 
			
		||||
	case "disabled":
 | 
			
		||||
		err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
 | 
			
		||||
		err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
 | 
			
		||||
	default:
 | 
			
		||||
		err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
 | 
			
		||||
	}
 | 
			
		||||
@@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelStatusById(id int, status int) {
 | 
			
		||||
	err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
 | 
			
		||||
	err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.SysError("failed to update ability status: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
@@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DeleteDisabledChannel() (int64, error) {
 | 
			
		||||
	result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
 | 
			
		||||
	result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
 | 
			
		||||
	return result.RowsAffected, result.Error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,8 +32,8 @@ func CreateRootAccountIfNeed() error {
 | 
			
		||||
		rootUser := User{
 | 
			
		||||
			Username:    "root",
 | 
			
		||||
			Password:    hashedPassword,
 | 
			
		||||
			Role:        common.RoleRootUser,
 | 
			
		||||
			Status:      common.UserStatusEnabled,
 | 
			
		||||
			Role:        RoleRootUser,
 | 
			
		||||
			Status:      UserStatusEnabled,
 | 
			
		||||
			DisplayName: "Root User",
 | 
			
		||||
			AccessToken: random.GetUUID(),
 | 
			
		||||
			Quota:       500000000000000,
 | 
			
		||||
@@ -45,7 +45,7 @@ func CreateRootAccountIfNeed() error {
 | 
			
		||||
				Id:             1,
 | 
			
		||||
				UserId:         rootUser.Id,
 | 
			
		||||
				Key:            config.InitialRootToken,
 | 
			
		||||
				Status:         common.TokenStatusEnabled,
 | 
			
		||||
				Status:         TokenStatusEnabled,
 | 
			
		||||
				Name:           "Initial Root Token",
 | 
			
		||||
				CreatedTime:    helper.GetTimestamp(),
 | 
			
		||||
				AccessedTime:   helper.GetTimestamp(),
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,12 @@ import (
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	RedemptionCodeStatusDisabled = 2 // also don't use 0
 | 
			
		||||
	RedemptionCodeStatusUsed     = 3 // also don't use 0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Redemption struct {
 | 
			
		||||
	Id           int    `json:"id"`
 | 
			
		||||
	UserId       int    `json:"user_id"`
 | 
			
		||||
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errors.New("无效的兑换码")
 | 
			
		||||
		}
 | 
			
		||||
		if redemption.Status != common.RedemptionCodeStatusEnabled {
 | 
			
		||||
		if redemption.Status != RedemptionCodeStatusEnabled {
 | 
			
		||||
			return errors.New("该兑换码已被使用")
 | 
			
		||||
		}
 | 
			
		||||
		err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
			
		||||
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		redemption.RedeemedTime = helper.GetTimestamp()
 | 
			
		||||
		redemption.Status = common.RedemptionCodeStatusUsed
 | 
			
		||||
		redemption.Status = RedemptionCodeStatusUsed
 | 
			
		||||
		err = tx.Save(redemption).Error
 | 
			
		||||
		return err
 | 
			
		||||
	})
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,13 @@ import (
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TokenStatusEnabled   = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	TokenStatusDisabled  = 2 // also don't use 0
 | 
			
		||||
	TokenStatusExpired   = 3
 | 
			
		||||
	TokenStatusExhausted = 4
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Token struct {
 | 
			
		||||
	Id             int     `json:"id"`
 | 
			
		||||
	UserId         int     `json:"user_id"`
 | 
			
		||||
@@ -62,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
		}
 | 
			
		||||
		return nil, errors.New("令牌验证失败")
 | 
			
		||||
	}
 | 
			
		||||
	if token.Status == common.TokenStatusExhausted {
 | 
			
		||||
	if token.Status == TokenStatusExhausted {
 | 
			
		||||
		return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id)
 | 
			
		||||
	} else if token.Status == common.TokenStatusExpired {
 | 
			
		||||
	} else if token.Status == TokenStatusExpired {
 | 
			
		||||
		return nil, errors.New("该令牌已过期")
 | 
			
		||||
	}
 | 
			
		||||
	if token.Status != common.TokenStatusEnabled {
 | 
			
		||||
	if token.Status != TokenStatusEnabled {
 | 
			
		||||
		return nil, errors.New("该令牌状态不可用")
 | 
			
		||||
	}
 | 
			
		||||
	if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
 | 
			
		||||
		if !common.RedisEnabled {
 | 
			
		||||
			token.Status = common.TokenStatusExpired
 | 
			
		||||
			token.Status = TokenStatusExpired
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("failed to update token status" + err.Error())
 | 
			
		||||
@@ -83,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
	if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 | 
			
		||||
		if !common.RedisEnabled {
 | 
			
		||||
			// in this case, we can make sure the token is exhausted
 | 
			
		||||
			token.Status = common.TokenStatusExhausted
 | 
			
		||||
			token.Status = TokenStatusExhausted
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("failed to update token status" + err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,19 @@ import (
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
	RoleAdminUser  = 10
 | 
			
		||||
	RoleRootUser   = 100
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
 | 
			
		||||
	UserStatusDisabled = 2 // also don't use 0
 | 
			
		||||
	UserStatusDeleted  = 3
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
 | 
			
		||||
// Otherwise, the sensitive information will be saved on local storage in plain text!
 | 
			
		||||
type User struct {
 | 
			
		||||
@@ -42,7 +55,7 @@ func GetMaxUserId() int {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
 | 
			
		||||
	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
 | 
			
		||||
	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
 | 
			
		||||
 | 
			
		||||
	switch order {
 | 
			
		||||
	case "quota":
 | 
			
		||||
@@ -138,9 +151,9 @@ func (user *User) Update(updatePassword bool) error {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if user.Status == common.UserStatusDisabled {
 | 
			
		||||
	if user.Status == UserStatusDisabled {
 | 
			
		||||
		blacklist.BanUser(user.Id)
 | 
			
		||||
	} else if user.Status == common.UserStatusEnabled {
 | 
			
		||||
	} else if user.Status == UserStatusEnabled {
 | 
			
		||||
		blacklist.UnbanUser(user.Id)
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Model(user).Updates(user).Error
 | 
			
		||||
@@ -153,7 +166,7 @@ func (user *User) Delete() error {
 | 
			
		||||
	}
 | 
			
		||||
	blacklist.BanUser(user.Id)
 | 
			
		||||
	user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
 | 
			
		||||
	user.Status = common.UserStatusDeleted
 | 
			
		||||
	user.Status = UserStatusDeleted
 | 
			
		||||
	err := DB.Model(user).Updates(user).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
@@ -177,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	okay := common.ValidatePasswordAndHash(password, user.Password)
 | 
			
		||||
	if !okay || user.Status != common.UserStatusEnabled {
 | 
			
		||||
	if !okay || user.Status != UserStatusEnabled {
 | 
			
		||||
		return errors.New("用户名或密码错误,或用户已被封禁")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
@@ -273,7 +286,7 @@ func IsAdmin(userId int) bool {
 | 
			
		||||
		logger.SysError("no such user " + err.Error())
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return user.Role >= common.RoleAdminUser
 | 
			
		||||
	return user.Role >= RoleAdminUser
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
@@ -285,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return user.Status == common.UserStatusEnabled, nil
 | 
			
		||||
	return user.Status == UserStatusEnabled, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ValidateAccessToken(token string) (user *User) {
 | 
			
		||||
@@ -358,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetRootUserEmail() (email string) {
 | 
			
		||||
	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
 | 
			
		||||
	DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
 | 
			
		||||
	return email
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package monitor
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/message"
 | 
			
		||||
@@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) {
 | 
			
		||||
 | 
			
		||||
// DisableChannel disable & notify
 | 
			
		||||
func DisableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
 | 
			
		||||
	logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
 | 
			
		||||
	subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | 
			
		||||
@@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func MetricDisableChannel(channelId int, successRate float64) {
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
 | 
			
		||||
	logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
 | 
			
		||||
	subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
 | 
			
		||||
	content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
 | 
			
		||||
@@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
 | 
			
		||||
 | 
			
		||||
// EnableChannel enable & notify
 | 
			
		||||
func EnableChannel(channelId int, channelName string) {
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
 | 
			
		||||
	model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
 | 
			
		||||
	logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
 | 
			
		||||
	subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
 | 
			
		||||
	content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
@@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	aiProxyLibraryRequest := ConvertRequest(*request)
 | 
			
		||||
	aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
 | 
			
		||||
	aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
 | 
			
		||||
	return aiProxyLibraryRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
@@ -47,8 +47,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 | 
			
		||||
	if meta.Mode == relaymode.ImagesGenerations {
 | 
			
		||||
		req.Header.Set("X-DashScope-Async", "enable")
 | 
			
		||||
	}
 | 
			
		||||
	if c.GetString(common.ConfigKeyPlugin) != "" {
 | 
			
		||||
		req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
 | 
			
		||||
	if c.GetString(config.KeyPlugin) != "" {
 | 
			
		||||
		req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,14 +2,14 @@ package azure
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetAPIVersion(c *gin.Context) string {
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString(common.ConfigKeyAPIVersion)
 | 
			
		||||
		apiVersion = c.GetString(config.KeyAPIVersion)
 | 
			
		||||
	}
 | 
			
		||||
	return apiVersion
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,6 +9,7 @@ import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/helper"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/random"
 | 
			
		||||
@@ -279,7 +280,7 @@ func getAPIVersion(c *gin.Context, modelName string) string {
 | 
			
		||||
		return apiVersion
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	apiVersion = c.GetString(common.ConfigKeyAPIVersion)
 | 
			
		||||
	apiVersion = c.GetString(config.KeyAPIVersion)
 | 
			
		||||
	if apiVersion != "" {
 | 
			
		||||
		return apiVersion
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										43
									
								
								relay/channeltype/url.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								relay/channeltype/url.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
package channeltype
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                              // 0
 | 
			
		||||
	"https://api.openai.com",        // 1
 | 
			
		||||
	"https://oa.api2d.net",          // 2
 | 
			
		||||
	"",                              // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz", // 4
 | 
			
		||||
	"https://api.openai-sb.com",     // 5
 | 
			
		||||
	"https://api.openaimax.com",     // 6
 | 
			
		||||
	"https://api.ohmygpt.com",       // 7
 | 
			
		||||
	"",                              // 8
 | 
			
		||||
	"https://api.caipacity.com",     // 9
 | 
			
		||||
	"https://api.aiproxy.io",        // 10
 | 
			
		||||
	"https://generativelanguage.googleapis.com", // 11
 | 
			
		||||
	"https://api.api2gpt.com",                   // 12
 | 
			
		||||
	"https://api.aigc2d.com",                    // 13
 | 
			
		||||
	"https://api.anthropic.com",                 // 14
 | 
			
		||||
	"https://aip.baidubce.com",                  // 15
 | 
			
		||||
	"https://open.bigmodel.cn",                  // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com",            // 17
 | 
			
		||||
	"",                                          // 18
 | 
			
		||||
	"https://ai.360.cn",                         // 19
 | 
			
		||||
	"https://openrouter.ai/api",                 // 20
 | 
			
		||||
	"https://api.aiproxy.io",                    // 21
 | 
			
		||||
	"https://fastgpt.run/api/openapi",           // 22
 | 
			
		||||
	"https://hunyuan.cloud.tencent.com",         // 23
 | 
			
		||||
	"https://generativelanguage.googleapis.com", // 24
 | 
			
		||||
	"https://api.moonshot.cn",                   // 25
 | 
			
		||||
	"https://api.baichuan-ai.com",               // 26
 | 
			
		||||
	"https://api.minimax.chat",                  // 27
 | 
			
		||||
	"https://api.mistral.ai",                    // 28
 | 
			
		||||
	"https://api.groq.com/openai",               // 29
 | 
			
		||||
	"http://localhost:11434",                    // 30
 | 
			
		||||
	"https://api.lingyiwanwu.com",               // 31
 | 
			
		||||
	"https://api.stepfun.com",                   // 32
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	if len(ChannelBaseURLs) != Dummy {
 | 
			
		||||
		panic("channel base urls length not match")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -119,7 +119,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	baseURL := channeltype.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,7 @@ package meta
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/config"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
			
		||||
@@ -41,7 +41,7 @@ func GetByContext(c *gin.Context) *Meta {
 | 
			
		||||
		Group:          c.GetString("group"),
 | 
			
		||||
		ModelMapping:   c.GetStringMapString("model_mapping"),
 | 
			
		||||
		BaseURL:        c.GetString("base_url"),
 | 
			
		||||
		APIVersion:     c.GetString(common.ConfigKeyAPIVersion),
 | 
			
		||||
		APIVersion:     c.GetString(config.KeyAPIVersion),
 | 
			
		||||
		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
 | 
			
		||||
		Config:         nil,
 | 
			
		||||
		RequestURLPath: c.Request.URL.String(),
 | 
			
		||||
@@ -50,7 +50,7 @@ func GetByContext(c *gin.Context) *Meta {
 | 
			
		||||
		meta.APIVersion = azure.GetAPIVersion(c)
 | 
			
		||||
	}
 | 
			
		||||
	if meta.BaseURL == "" {
 | 
			
		||||
		meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
 | 
			
		||||
		meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
 | 
			
		||||
	}
 | 
			
		||||
	meta.APIType = channeltype.ToAPIType(meta.ChannelType)
 | 
			
		||||
	return &meta
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user