mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 07:43:41 +08:00 
			
		
		
		
	Compare commits
	
		
			9 Commits
		
	
	
		
			v0.5.0-alp
			...
			v0.5.1-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					b9f6461dd4 | ||
| 
						 | 
					0a39521a3d | ||
| 
						 | 
					c134604cee | ||
| 
						 | 
					929e43ef81 | ||
| 
						 | 
					dce8bbe1ca | ||
| 
						 | 
					bc2f48b1f2 | ||
| 
						 | 
					889af8b2db | ||
| 
						 | 
					4eea096654 | ||
| 
						 | 
					4ab3211c0e | 
@@ -77,6 +77,8 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
 | 
				
			|||||||
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
					var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
 | 
				
			||||||
var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
					var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	RoleGuestUser  = 0
 | 
						RoleGuestUser  = 0
 | 
				
			||||||
	RoleCommonUser = 1
 | 
						RoleCommonUser = 1
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,14 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
					func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
				
			||||||
	switch channel.Type {
 | 
						switch channel.Type {
 | 
				
			||||||
 | 
						case common.ChannelTypePaLM:
 | 
				
			||||||
 | 
							fallthrough
 | 
				
			||||||
 | 
						case common.ChannelTypeAnthropic:
 | 
				
			||||||
 | 
							fallthrough
 | 
				
			||||||
 | 
						case common.ChannelTypeBaidu:
 | 
				
			||||||
 | 
							fallthrough
 | 
				
			||||||
 | 
						case common.ChannelTypeZhipu:
 | 
				
			||||||
 | 
							return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
				
			||||||
	case common.ChannelTypeAzure:
 | 
						case common.ChannelTypeAzure:
 | 
				
			||||||
		request.Model = "gpt-35-turbo"
 | 
							request.Model = "gpt-35-turbo"
 | 
				
			||||||
	default:
 | 
						default:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -57,11 +57,22 @@ type BaiduChatStreamResponse struct {
 | 
				
			|||||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
					func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
				
			||||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
						messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
				
			||||||
	for _, message := range request.Messages {
 | 
						for _, message := range request.Messages {
 | 
				
			||||||
 | 
							if message.Role == "system" {
 | 
				
			||||||
 | 
								messages = append(messages, BaiduMessage{
 | 
				
			||||||
 | 
									Role:    "user",
 | 
				
			||||||
 | 
									Content: message.Content,
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
								messages = append(messages, BaiduMessage{
 | 
				
			||||||
 | 
									Role:    "assistant",
 | 
				
			||||||
 | 
									Content: "Okay",
 | 
				
			||||||
 | 
								})
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
			messages = append(messages, BaiduMessage{
 | 
								messages = append(messages, BaiduMessage{
 | 
				
			||||||
				Role:    message.Role,
 | 
									Role:    message.Role,
 | 
				
			||||||
				Content: message.Content,
 | 
									Content: message.Content,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	return &BaiduChatRequest{
 | 
						return &BaiduChatRequest{
 | 
				
			||||||
		Messages: messages,
 | 
							Messages: messages,
 | 
				
			||||||
		Stream:   request.Stream,
 | 
							Stream:   request.Stream,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -69,11 +69,11 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
 | 
				
			|||||||
			prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
 | 
								prompt += fmt.Sprintf("\n\nHuman: %s", message.Content)
 | 
				
			||||||
		} else if message.Role == "assistant" {
 | 
							} else if message.Role == "assistant" {
 | 
				
			||||||
			prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
								prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
 | 
				
			||||||
		} else {
 | 
							} else if message.Role == "system" {
 | 
				
			||||||
			// ignore other roles
 | 
								prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	prompt += "\n\nAssistant:"
 | 
						prompt += "\n\nAssistant:"
 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	claudeRequest.Prompt = prompt
 | 
						claudeRequest.Prompt = prompt
 | 
				
			||||||
	return &claudeRequest
 | 
						return &claudeRequest
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -85,13 +85,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	apiType := APITypeOpenAI
 | 
						apiType := APITypeOpenAI
 | 
				
			||||||
	if strings.HasPrefix(textRequest.Model, "claude") {
 | 
						switch channelType {
 | 
				
			||||||
 | 
						case common.ChannelTypeAnthropic:
 | 
				
			||||||
		apiType = APITypeClaude
 | 
							apiType = APITypeClaude
 | 
				
			||||||
	} else if strings.HasPrefix(textRequest.Model, "ERNIE") {
 | 
						case common.ChannelTypeBaidu:
 | 
				
			||||||
		apiType = APITypeBaidu
 | 
							apiType = APITypeBaidu
 | 
				
			||||||
	} else if strings.HasPrefix(textRequest.Model, "PaLM") {
 | 
						case common.ChannelTypePaLM:
 | 
				
			||||||
		apiType = APITypePaLM
 | 
							apiType = APITypePaLM
 | 
				
			||||||
	} else if strings.HasPrefix(textRequest.Model, "chatglm_") {
 | 
						case common.ChannelTypeZhipu:
 | 
				
			||||||
		apiType = APITypeZhipu
 | 
							apiType = APITypeZhipu
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
						baseURL := common.ChannelBaseURLs[channelType]
 | 
				
			||||||
@@ -140,6 +141,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
				
			|||||||
		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 | 
							fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 | 
				
			||||||
	case APITypePaLM:
 | 
						case APITypePaLM:
 | 
				
			||||||
		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 | 
							fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 | 
				
			||||||
 | 
							if baseURL != "" {
 | 
				
			||||||
 | 
								fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
							apiKey := c.Request.Header.Get("Authorization")
 | 
				
			||||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
							apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
				
			||||||
		fullRequestURL += "?key=" + apiKey
 | 
							fullRequestURL += "?key=" + apiKey
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										1
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								main.go
									
									
									
									
									
								
							@@ -54,6 +54,7 @@ func main() {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
 | 
								common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
							common.SyncFrequency = frequency
 | 
				
			||||||
		go model.SyncOptions(frequency)
 | 
							go model.SyncOptions(frequency)
 | 
				
			||||||
		if common.RedisEnabled {
 | 
							if common.RedisEnabled {
 | 
				
			||||||
			go model.SyncChannelCache(frequency)
 | 
								go model.SyncChannelCache(frequency)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,11 +12,11 @@ import (
 | 
				
			|||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					var (
 | 
				
			||||||
	TokenCacheSeconds         = 60 * 60
 | 
						TokenCacheSeconds         = common.SyncFrequency
 | 
				
			||||||
	UserId2GroupCacheSeconds  = 60 * 60
 | 
						UserId2GroupCacheSeconds  = common.SyncFrequency
 | 
				
			||||||
	UserId2QuotaCacheSeconds  = 10 * 60
 | 
						UserId2QuotaCacheSeconds  = common.SyncFrequency
 | 
				
			||||||
	UserId2StatusCacheSeconds = 60 * 60
 | 
						UserId2StatusCacheSeconds = common.SyncFrequency
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func CacheGetTokenByKey(key string) (*Token, error) {
 | 
					func CacheGetTokenByKey(key string) (*Token, error) {
 | 
				
			||||||
@@ -35,7 +35,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), TokenCacheSeconds*time.Second)
 | 
							err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			common.SysError("Redis set token error: " + err.Error())
 | 
								common.SysError("Redis set token error: " + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -55,7 +55,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return "", err
 | 
								return "", err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, UserId2GroupCacheSeconds*time.Second)
 | 
							err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			common.SysError("Redis set user group error: " + err.Error())
 | 
								common.SysError("Redis set user group error: " + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -73,7 +73,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
 | 
				
			|||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return 0, err
 | 
								return 0, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
 | 
							err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			common.SysError("Redis set user quota error: " + err.Error())
 | 
								common.SysError("Redis set user quota error: " + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
@@ -91,7 +91,7 @@ func CacheUpdateUserQuota(id int) error {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), UserId2QuotaCacheSeconds*time.Second)
 | 
						err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
 | 
				
			||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -106,7 +106,7 @@ func CacheIsUserEnabled(userId int) bool {
 | 
				
			|||||||
			status = common.UserStatusEnabled
 | 
								status = common.UserStatusEnabled
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		enabled = fmt.Sprintf("%d", status)
 | 
							enabled = fmt.Sprintf("%d", status)
 | 
				
			||||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, UserId2StatusCacheSeconds*time.Second)
 | 
							err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			common.SysError("Redis set user enabled error: " + err.Error())
 | 
								common.SysError("Redis set user enabled error: " + err.Error())
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -51,20 +51,21 @@ func Redeem(key string, userId int) (quota int, err error) {
 | 
				
			|||||||
	redemption := &Redemption{}
 | 
						redemption := &Redemption{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err = DB.Transaction(func(tx *gorm.DB) error {
 | 
						err = DB.Transaction(func(tx *gorm.DB) error {
 | 
				
			||||||
		err := DB.Where("`key` = ?", key).First(redemption).Error
 | 
							err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return errors.New("无效的兑换码")
 | 
								return errors.New("无效的兑换码")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		if redemption.Status != common.RedemptionCodeStatusEnabled {
 | 
							if redemption.Status != common.RedemptionCodeStatusEnabled {
 | 
				
			||||||
			return errors.New("该兑换码已被使用")
 | 
								return errors.New("该兑换码已被使用")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		err = DB.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
							err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		redemption.RedeemedTime = common.GetTimestamp()
 | 
							redemption.RedeemedTime = common.GetTimestamp()
 | 
				
			||||||
		redemption.Status = common.RedemptionCodeStatusUsed
 | 
							redemption.Status = common.RedemptionCodeStatusUsed
 | 
				
			||||||
		return redemption.SelectUpdate()
 | 
							err = tx.Save(redemption).Error
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return 0, errors.New("兑换失败," + err.Error())
 | 
							return 0, errors.New("兑换失败," + err.Error())
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,7 +12,7 @@ func SetRelayRouter(router *gin.Engine) {
 | 
				
			|||||||
	modelsRouter := router.Group("/v1/models")
 | 
						modelsRouter := router.Group("/v1/models")
 | 
				
			||||||
	modelsRouter.Use(middleware.TokenAuth())
 | 
						modelsRouter.Use(middleware.TokenAuth())
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		modelsRouter.GET("/", controller.ListModels)
 | 
							modelsRouter.GET("", controller.ListModels)
 | 
				
			||||||
		modelsRouter.GET("/:model", controller.RetrieveModel)
 | 
							modelsRouter.GET("/:model", controller.RetrieveModel)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	relayV1Router := router.Group("/v1")
 | 
						relayV1Router := router.Group("/v1")
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -288,7 +288,7 @@ const PersonalSetting = () => {
 | 
				
			|||||||
            <Form size='large'>
 | 
					            <Form size='large'>
 | 
				
			||||||
              <Form.Input
 | 
					              <Form.Input
 | 
				
			||||||
                fluid
 | 
					                fluid
 | 
				
			||||||
                placeholder={`输入你的账户名 ${userState.user.username} 以确认删除`}
 | 
					                placeholder={`输入你的账户名 ${userState?.user?.username} 以确认删除`}
 | 
				
			||||||
                name='self_account_deletion_confirmation'
 | 
					                name='self_account_deletion_confirmation'
 | 
				
			||||||
                value={inputs.self_account_deletion_confirmation}
 | 
					                value={inputs.self_account_deletion_confirmation}
 | 
				
			||||||
                onChange={handleInputChange}
 | 
					                onChange={handleInputChange}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user