From 49cad7d4a542b92ccc3bfec491259f1b8894c107 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 13 Mar 2024 19:11:30 +0800 Subject: [PATCH 01/22] feat: update func ShouldDisableChannel for claude --- relay/util/common.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/relay/util/common.go b/relay/util/common.go index 20257488..dbd724b4 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -35,10 +35,17 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { return true case "permission_error": return true + case "forbidden": + return true } if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { return true } + if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } return false } From 0710f8cd66cfb2b13ad25b0fdf4212d74e5abf73 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 13 Mar 2024 19:26:24 +0800 Subject: [PATCH 02/22] fix: when cached quota is too low, force refresh it --- model/cache.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/model/cache.go b/model/cache.go index 3c3575b8..3a9f8023 100644 --- a/model/cache.go +++ b/model/cache.go @@ -70,23 +70,30 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } +func fetchAndUpdateUserQuota(id int) (quota int, err error) { + quota, err = GetUserQuota(id) + if err != nil { + return 0, err + } + err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set user quota error: " + err.Error()) + } + return +} + func CacheGetUserQuota(id int) (quota int, err error) { if !common.RedisEnabled { return GetUserQuota(id) } quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) if err != nil { - quota, err = GetUserQuota(id) - if err != nil { - return 0, err - } - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) - if err != nil { - logger.SysError("Redis set user quota error: " + err.Error()) - } - return quota, err + return fetchAndUpdateUserQuota(id) } quota, err = strconv.Atoi(quotaString) + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + return fetchAndUpdateUserQuota(id) + } return quota, err } From a72e5fcc9e9d6fbabd53630c66971ecec8a53f47 Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 13 Mar 2024 19:38:44 +0800 Subject: [PATCH 03/22] fix: when cached quota is too low, force refresh it --- model/cache.go | 23 ++++++++++++++--------- relay/controller/audio.go | 3 ++- relay/controller/helper.go | 4 ++-- relay/controller/image.go | 4 ++-- relay/util/common.go | 2 +- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/model/cache.go b/model/cache.go index 3a9f8023..9fd29e7a 100644 --- a/model/cache.go +++ b/model/cache.go @@ -1,6 +1,7 @@ package model import ( + "context" "encoding/json" "errors" "fmt" @@ -70,38 +71,42 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -func fetchAndUpdateUserQuota(id int) (quota int, err error) { +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int, err error) { quota, err = GetUserQuota(id) if err != nil { return 0, err } err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", 
quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) if err != nil { - logger.SysError("Redis set user quota error: " + err.Error()) + logger.Error(ctx, "Redis set user quota error: "+err.Error()) } return } -func CacheGetUserQuota(id int) (quota int, err error) { +func CacheGetUserQuota(ctx context.Context, id int) (quota int, err error) { if !common.RedisEnabled { return GetUserQuota(id) } quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id)) if err != nil { - return fetchAndUpdateUserQuota(id) + return fetchAndUpdateUserQuota(ctx, id) } quota, err = strconv.Atoi(quotaString) - if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db - return fetchAndUpdateUserQuota(id) + if err != nil { + return 0, nil } - return quota, err + if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db + logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id) + return fetchAndUpdateUserQuota(ctx, id) + } + return quota, nil } -func CacheUpdateUserQuota(id int) error { +func CacheUpdateUserQuota(ctx context.Context, id int) error { if !common.RedisEnabled { return nil } - quota, err := CacheGetUserQuota(id) + quota, err := CacheGetUserQuota(ctx, id) if err != nil { return err } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index ee8771c9..3d5086e9 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -22,6 +22,7 @@ import ( ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() audioModel := "whisper-1" tokenId := c.GetInt("token_id") @@ -58,7 +59,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus default: preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) } - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 89fc69ce..d0cb424f 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -118,7 +118,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) if err != nil { return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) } @@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R if err != nil { logger.Error(ctx, "error consuming token remain quota: "+err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.Error(ctx, "error update user quota cache: "+err.Error()) } diff --git a/relay/controller/image.go b/relay/controller/image.go index 3ce3809b..a0428ecb 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -79,7 +79,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus modelRatio := 
common.GetModelRatio(imageRequest.Model) groupRatio := common.GetGroupRatio(meta.Group) ratio := modelRatio * groupRatio - userQuota, err := model.CacheGetUserQuota(meta.UserId) + userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) quota := int(ratio*imageCostRatio*1000) * imageRequest.N @@ -125,7 +125,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(meta.UserId) + err = model.CacheUpdateUserQuota(ctx, meta.UserId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } diff --git a/relay/util/common.go b/relay/util/common.go index dbd724b4..ff0e293d 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -161,7 +161,7 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo if err != nil { logger.SysError("error consuming token remain quota: " + err.Error()) } - err = model.CacheUpdateUserQuota(userId) + err = model.CacheUpdateUserQuota(ctx, userId) if err != nil { logger.SysError("error update user quota cache: " + err.Error()) } From e99150bdb99c5d8db901195479a48fa53a9c94ab Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 13 Mar 2024 20:00:51 +0800 Subject: [PATCH 04/22] fix: make quota int64 --- common/config/config.go | 10 +++++----- common/utils.go | 2 +- controller/billing.go | 6 +++--- model/cache.go | 8 ++++---- model/channel.go | 4 ++-- model/log.go | 6 +++--- model/option.go | 20 ++++++++++---------- model/redemption.go | 4 ++-- model/token.go | 16 ++++++++-------- model/user.go | 22 +++++++++++----------- model/utils.go | 10 +++++----- relay/controller/audio.go | 10 +++++----- relay/controller/helper.go | 14 +++++++------- relay/controller/image.go | 2 +- relay/util/billing.go | 2 +- relay/util/common.go | 4 ++-- 16 files changed, 70 insertions(+), 70 deletions(-) diff --git a/common/config/config.go b/common/config/config.go index 53af824f..c62a6ac6 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -76,14 +76,14 @@ var MessagePusherToken = "" var TurnstileSiteKey = "" var TurnstileSecretKey = "" -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 +var QuotaForNewUser int64 = 0 +var QuotaForInviter int64 = 0 +var QuotaForInvitee int64 = 0 var ChannelDisableThreshold = 5.0 var AutomaticDisableChannelEnabled = false var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 +var QuotaRemindThreshold int64 = 1000 +var PreConsumedQuota int64 = 500 var ApproximateTokenEnabled = false var RetryTimes = 0 diff --git a/common/utils.go b/common/utils.go index 24615225..ecee2c8e 100644 --- a/common/utils.go +++ b/common/utils.go @@ -5,7 +5,7 @@ import ( "github.com/songquanpeng/one-api/common/config" ) -func LogQuota(quota int) string { +func LogQuota(quota int64) string { if config.DisplayInCurrencyEnabled { return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { diff --git a/controller/billing.go b/controller/billing.go index 7317913d..dd518678 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -8,8 +8,8 @@ import ( ) func GetSubscription(c *gin.Context) { - var remainQuota int - var usedQuota int + var remainQuota int64 + var usedQuota int64 var err error var token *model.Token var expiredTime int64 @@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) { } func GetUsage(c *gin.Context) { - var quota int + var quota int64 var err error 
var token *model.Token if config.DisplayTokenStatEnabled { diff --git a/model/cache.go b/model/cache.go index 9fd29e7a..dd20d857 100644 --- a/model/cache.go +++ b/model/cache.go @@ -71,7 +71,7 @@ func CacheGetUserGroup(id int) (group string, err error) { return group, err } -func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int, err error) { +func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) { quota, err = GetUserQuota(id) if err != nil { return 0, err @@ -83,7 +83,7 @@ func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int, err error) return } -func CacheGetUserQuota(ctx context.Context, id int) (quota int, err error) { +func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) { if !common.RedisEnabled { return GetUserQuota(id) } @@ -91,7 +91,7 @@ func CacheGetUserQuota(ctx context.Context, id int) (quota int, err error) { if err != nil { return fetchAndUpdateUserQuota(ctx, id) } - quota, err = strconv.Atoi(quotaString) + quota, err = strconv.ParseInt(quotaString, 10, 64) if err != nil { return 0, nil } @@ -114,7 +114,7 @@ func CacheUpdateUserQuota(ctx context.Context, id int) error { return err } -func CacheDecreaseUserQuota(id int, quota int) error { +func CacheDecreaseUserQuota(id int, quota int64) error { if !common.RedisEnabled { return nil } diff --git a/model/channel.go b/model/channel.go index 605c6d17..fc4905b1 100644 --- a/model/channel.go +++ b/model/channel.go @@ -178,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) { } } -func UpdateChannelUsedQuota(id int, quota int) { +func UpdateChannelUsedQuota(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return @@ -186,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) { updateChannelUsedQuota(id, quota) } -func updateChannelUsedQuota(id int, quota int) { +func updateChannelUsedQuota(id int, quota int64) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { logger.SysError("failed to update channel used quota: " + err.Error()) diff --git a/model/log.go b/model/log.go index 9615c237..85c0ba90 100644 --- a/model/log.go +++ b/model/log.go @@ -51,7 +51,7 @@ func RecordLog(userId int, logType int, content string) { } } -func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { +func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { return @@ -66,7 +66,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke CompletionTokens: completionTokens, TokenName: tokenName, ModelName: modelName, - Quota: quota, + Quota: int(quota), ChannelId: channelId, } err := DB.Create(log).Error @@ -137,7 +137,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { return logs, err } -func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { +func SumUsedQuota(logType 
int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { tx := DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) diff --git a/model/option.go b/model/option.go index e129b9f0..1d1c28b4 100644 --- a/model/option.go +++ b/model/option.go @@ -61,11 +61,11 @@ func InitOptionMap() { config.OptionMap["MessagePusherToken"] = "" config.OptionMap["TurnstileSiteKey"] = "" config.OptionMap["TurnstileSecretKey"] = "" - config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) - config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) - config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) - config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) - config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10) + config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10) + config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) + config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) + config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() @@ -193,15 +193,15 @@ func updateOptionMap(key string, value string) (err error) { case "TurnstileSecretKey": config.TurnstileSecretKey = value case "QuotaForNewUser": - config.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInviter": - config.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64) case "QuotaForInvitee": - config.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64) case "QuotaRemindThreshold": - config.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64) case "PreConsumedQuota": - config.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64) case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": diff --git a/model/redemption.go b/model/redemption.go index 2c5a4141..e0ae68e2 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -14,7 +14,7 @@ type Redemption struct { Key string `json:"key" gorm:"type:char(32);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` - Quota int `json:"quota" gorm:"default:100"` + Quota int64 `json:"quota" gorm:"default:100"` CreatedTime int64 `json:"created_time" gorm:"bigint"` RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"` Count int `json:"count" gorm:"-:all"` // only for api request @@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) { return &redemption, err } -func Redeem(key string, userId int) (quota int, err error) { +func Redeem(key string, userId int) (quota int64, err error) { if key == "" { return 0, errors.New("未提供兑换码") } diff --git a/model/token.go b/model/token.go index c4669e0b..40d0eb8f 100644 --- a/model/token.go +++ b/model/token.go @@ -20,9 +20,9 @@ type Token struct { CreatedTime int64 
`json:"created_time" gorm:"bigint"` AccessedTime int64 `json:"accessed_time" gorm:"bigint"` ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int `json:"remain_quota" gorm:"default:0"` + RemainQuota int64 `json:"remain_quota" gorm:"default:0"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota + UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota } func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) { @@ -138,7 +138,7 @@ func DeleteTokenById(id int, userId int) (err error) { return token.Delete() } -func IncreaseTokenQuota(id int, quota int) (err error) { +func IncreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -149,7 +149,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { return increaseTokenQuota(id, quota) } -func increaseTokenQuota(id int, quota int) (err error) { +func increaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), @@ -160,7 +160,7 @@ func increaseTokenQuota(id int, quota int) (err error) { return err } -func DecreaseTokenQuota(id int, quota int) (err error) { +func DecreaseTokenQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -171,7 +171,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { return decreaseTokenQuota(id, quota) } -func decreaseTokenQuota(id int, quota int) (err error) { +func decreaseTokenQuota(id int, quota int64) (err error) { err = DB.Model(&Token{}).Where("id = ?", id).Updates( map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), @@ -182,7 +182,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { return err } -func PreConsumeTokenQuota(tokenId int, quota int) (err error) { +func PreConsumeTokenQuota(tokenId int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -232,7 +232,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { return err } -func PostConsumeTokenQuota(tokenId int, quota int) (err error) { +func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { token, err := GetTokenById(tokenId) if quota > 0 { err = DecreaseUserQuota(token.UserId, quota) diff --git a/model/user.go b/model/user.go index dcbd6ff1..e325394b 100644 --- a/model/user.go +++ b/model/user.go @@ -26,8 +26,8 @@ type User struct { WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! 
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management - Quota int `json:"quota" gorm:"type:int;default:0"` - UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota + Quota int64 `json:"quota" gorm:"type:int;default:0"` + UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number Group string `json:"group" gorm:"type:varchar(32);default:'default'"` AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"` @@ -274,12 +274,12 @@ func ValidateAccessToken(token string) (user *User) { return nil } -func GetUserQuota(id int) (quota int, err error) { +func GetUserQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error return quota, err } -func GetUserUsedQuota(id int) (quota int, err error) { +func GetUserUsedQuota(id int) (quota int64, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error return quota, err } @@ -299,7 +299,7 @@ func GetUserGroup(id int) (group string, err error) { return group, err } -func IncreaseUserQuota(id int, quota int) (err error) { +func IncreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -310,12 +310,12 @@ func IncreaseUserQuota(id int, quota int) (err error) { return increaseUserQuota(id, quota) } -func increaseUserQuota(id int, quota int) (err error) { +func increaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error return err } -func DecreaseUserQuota(id int, quota int) (err error) { +func DecreaseUserQuota(id int, quota int64) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } @@ -326,7 +326,7 @@ func DecreaseUserQuota(id int, quota int) (err error) { return decreaseUserQuota(id, quota) } -func decreaseUserQuota(id int, quota int) (err error) { +func decreaseUserQuota(id int, quota int64) (err error) { err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error return err } @@ -336,7 +336,7 @@ func GetRootUserEmail() (email string) { return email } -func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { +func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) { if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeUsedQuota, id, quota) addNewRecord(BatchUpdateTypeRequestCount, id, 1) @@ -345,7 +345,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { updateUserUsedQuotaAndRequestCount(id, quota, 1) } -func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { +func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), @@ -357,7 +357,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { } } -func updateUserUsedQuota(id int, quota int) { +func updateUserUsedQuota(id int, quota int64) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ "used_quota": gorm.Expr("used_quota + ?", quota), diff --git a/model/utils.go b/model/utils.go index d481973a..a55eb4b6 100644 --- a/model/utils.go +++ b/model/utils.go @@ -16,12 +16,12 @@ const ( BatchUpdateTypeCount // if you add a new type, you need to add 
a new map and a new lock ) -var batchUpdateStores []map[int]int +var batchUpdateStores []map[int]int64 var batchUpdateLocks []sync.Mutex func init() { for i := 0; i < BatchUpdateTypeCount; i++ { - batchUpdateStores = append(batchUpdateStores, make(map[int]int)) + batchUpdateStores = append(batchUpdateStores, make(map[int]int64)) batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{}) } } @@ -35,7 +35,7 @@ func InitBatchUpdater() { }() } -func addNewRecord(type_ int, id int, value int) { +func addNewRecord(type_ int, id int, value int64) { batchUpdateLocks[type_].Lock() defer batchUpdateLocks[type_].Unlock() if _, ok := batchUpdateStores[type_][id]; !ok { @@ -50,7 +50,7 @@ func batchUpdate() { for i := 0; i < BatchUpdateTypeCount; i++ { batchUpdateLocks[i].Lock() store := batchUpdateStores[i] - batchUpdateStores[i] = make(map[int]int) + batchUpdateStores[i] = make(map[int]int64) batchUpdateLocks[i].Unlock() // TODO: maybe we can combine updates with same key? for key, value := range store { @@ -68,7 +68,7 @@ func batchUpdate() { case BatchUpdateTypeUsedQuota: updateUserUsedQuota(key, value) case BatchUpdateTypeRequestCount: - updateUserRequestCount(key, value) + updateUserRequestCount(key, int(value)) case BatchUpdateTypeChannelUsedQuota: updateChannelUsedQuota(key, value) } diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 3d5086e9..155954d2 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -50,14 +50,14 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus modelRatio := common.GetModelRatio(audioModel) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio - var quota int - var preConsumedQuota int + var quota int64 + var preConsumedQuota int64 switch relayMode { case constant.RelayModeAudioSpeech: - preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio) + preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: - preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio) + preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) } userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { @@ -184,7 +184,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) } - quota = openai.CountTokenText(text, audioModel) + quota = int64(openai.CountTokenText(text, audioModel)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { diff --git a/relay/controller/helper.go b/relay/controller/helper.go index d0cb424f..600a8d65 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -107,15 +107,15 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int return 0 } -func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int { +func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 { preConsumedTokens := config.PreConsumedQuota if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens + preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens) } - return int(float64(preConsumedTokens) * ratio) + return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, 
meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } - quota := 0 + var quota int64 completionRatio := common.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) + quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) if ratio != 0 && quota <= 0 { quota = 1 } diff --git a/relay/controller/image.go b/relay/controller/image.go index a0428ecb..20ea0a4c 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -81,7 +81,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int(ratio*imageCostRatio*1000) * imageRequest.N + quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) if userQuota-quota < 0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) diff --git a/relay/util/billing.go b/relay/util/billing.go index 1e2b09ea..495d011e 100644 --- a/relay/util/billing.go +++ b/relay/util/billing.go @@ -6,7 +6,7 @@ import ( "github.com/songquanpeng/one-api/model" ) -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) { +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { if preConsumedQuota != 0 { go func(ctx context.Context) { // return pre-consumed quota diff --git a/relay/util/common.go b/relay/util/common.go index ff0e293d..535ef680 100644 --- a/relay/util/common.go +++ b/relay/util/common.go @@ -155,7 +155,7 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin return fullRequestURL } -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { // quotaDelta is remaining quota to be consumed err := model.PostConsumeTokenQuota(tokenId, quotaDelta) if err != nil { @@ -168,7 +168,7 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuo // totalQuota is total quota consumed if totalQuota != 0 { logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", 
modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota) } From 79d0cd378ac03597261b31a21f5cce8a680c9c2b Mon Sep 17 00:00:00 2001 From: JustSong Date: Wed, 13 Mar 2024 22:56:54 +0800 Subject: [PATCH 05/22] fix: fix baidu system prompt (close #1079) --- common/config/config.go | 26 +++++++++++------------ common/database.go | 6 ++++-- common/env/helper.go | 42 +++++++++++++++++++++++++++++++++++++ common/helper/helper.go | 40 ----------------------------------- common/logger/logger.go | 5 ++++- model/main.go | 7 ++++--- relay/channel/baidu/main.go | 41 +++++++++++++++++++++--------------- relay/controller/text.go | 1 + 8 files changed, 92 insertions(+), 76 deletions(-) create mode 100644 common/env/helper.go diff --git a/common/config/config.go b/common/config/config.go index c62a6ac6..83cfa933 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,7 +1,7 @@ package config import ( - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/env" "os" "strconv" "sync" @@ -94,16 +94,16 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second +var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second -var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE") -var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var Theme = env.String("THEME", "default") var ValidThemes = map[string]bool{ "default": true, "berry": true, @@ -112,10 +112,10 @@ var ValidThemes = map[string]bool{ // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 @@ -130,8 +130,8 @@ var ( var RateLimitKeyExpirationDuration = 20 * time.Minute -var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false) -var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10) -var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) -var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024) -var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128) +var EnableMetric = env.Bool("ENABLE_METRIC", false) +var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10) +var 
MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8) +var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024) +var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128) diff --git a/common/database.go b/common/database.go index df60bdd5..f2db759f 100644 --- a/common/database.go +++ b/common/database.go @@ -1,10 +1,12 @@ package common -import "github.com/songquanpeng/one-api/common/helper" +import ( + "github.com/songquanpeng/one-api/common/env" +) var UsingSQLite = false var UsingPostgreSQL = false var UsingMySQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/env/helper.go b/common/env/helper.go new file mode 100644 index 00000000..fdb9f827 --- /dev/null +++ b/common/env/helper.go @@ -0,0 +1,42 @@ +package env + +import ( + "os" + "strconv" +) + +func Bool(env string, defaultValue bool) bool { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) == "true" +} + +func Int(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + return defaultValue + } + return num +} + +func Float64(env string, defaultValue float64) float64 { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.ParseFloat(os.Getenv(env), 64) + if err != nil { + return defaultValue + } + return num +} + +func String(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/helper/helper.go b/common/helper/helper.go index 23578842..76db5042 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -3,12 +3,10 @@ package helper import ( "fmt" "github.com/google/uuid" - "github.com/songquanpeng/one-api/common/logger" "html/template" "log" "math/rand" "net" - "os" "os/exec" "runtime" "strconv" @@ -195,44 +193,6 @@ func Max(a int, b int) int { } } -func GetOrDefaultEnvBool(env string, defaultValue bool) bool { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) == "true" -} - -func GetOrDefaultEnvInt(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.ParseFloat(os.Getenv(env), 64) - if err != nil { - logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultEnvString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func AssignOrDefault(value string, defaultValue string) string { if len(value) != 0 { return value diff --git a/common/logger/logger.go b/common/logger/logger.go index 41b98ca3..ad0a0bea 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "io" 
"log" "os" @@ -54,7 +55,9 @@ func SysError(s string) { } func Debug(ctx context.Context, msg string) { - logHelper(ctx, loggerDEBUG, msg) + if config.DebugEnabled { + logHelper(ctx, loggerDEBUG, msg) + } } func Info(ctx context.Context, msg string) { diff --git a/model/main.go b/model/main.go index f27cdb6f..05150fd9 100644 --- a/model/main.go +++ b/model/main.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/driver/mysql" @@ -81,9 +82,9 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { return nil diff --git a/relay/channel/baidu/main.go b/relay/channel/baidu/main.go index 4f2b13fc..9ca9e47d 100644 --- a/relay/channel/baidu/main.go +++ b/relay/channel/baidu/main.go @@ -32,9 +32,16 @@ type Message struct { } type ChatRequest struct { - Messages []Message `json:"messages"` - Stream bool `json:"stream"` - UserId string `json:"user_id,omitempty"` + Messages []Message `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + PenaltyScore float64 `json:"penalty_score,omitempty"` + Stream bool `json:"stream,omitempty"` + System string `json:"system,omitempty"` + DisableSearch bool `json:"disable_search,omitempty"` + EnableCitation bool `json:"enable_citation,omitempty"` + MaxOutputTokens int `json:"max_output_tokens,omitempty"` + UserId string `json:"user_id,omitempty"` } type Error struct { @@ -45,28 +52,28 @@ type Error struct { var baiduTokenStore sync.Map func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { - messages := make([]Message, 0, len(request.Messages)) + baiduRequest := ChatRequest{ + Messages: make([]Message, 0, len(request.Messages)), + Temperature: request.Temperature, + TopP: request.TopP, + PenaltyScore: request.FrequencyPenalty, + Stream: request.Stream, + DisableSearch: false, + EnableCitation: false, + MaxOutputTokens: request.MaxTokens, + UserId: request.User, + } for _, message := range request.Messages { if message.Role == "system" { - messages = append(messages, Message{ - Role: "user", - Content: message.StringContent(), - }) - messages = append(messages, Message{ - Role: "assistant", - Content: "Okay", - }) + baiduRequest.System = message.StringContent() } else { - messages = append(messages, Message{ + baiduRequest.Messages = append(baiduRequest.Messages, Message{ Role: message.Role, Content: message.StringContent(), }) } } - return &ChatRequest{ - Messages: messages, - Stream: request.Stream, - } + return &baiduRequest } func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse { diff --git a/relay/controller/text.go b/relay/controller/text.go index 781170f4..ba008713 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -74,6 +74,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { if err != nil { return 
openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError) } + logger.Debugf(ctx, "converted request: \n%s", string(jsonData)) requestBody = bytes.NewBuffer(jsonData) } From 2dcef8528582bb8cacf7c53f3ec2a5b3577fe6bb Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 01:02:47 +0800 Subject: [PATCH 06/22] feat: support ollama now (close #870) --- README.md | 1 + common/constants.go | 2 + common/helper/helper.go | 4 + common/logger/logger.go | 4 + middleware/request-id.go | 2 +- relay/channel/ollama/adaptor.go | 65 +++++++ relay/channel/ollama/constants.go | 5 + relay/channel/ollama/main.go | 178 ++++++++++++++++++ relay/channel/ollama/model.go | 37 ++++ relay/constant/api_type.go | 3 + relay/helper/main.go | 3 + web/berry/src/constants/ChannelConstants.js | 6 + web/berry/src/views/Channel/type/Config.js | 3 + .../src/constants/channel.constants.js | 1 + 14 files changed, 313 insertions(+), 1 deletion(-) create mode 100644 relay/channel/ollama/adaptor.go create mode 100644 relay/channel/ollama/constants.go create mode 100644 relay/channel/ollama/main.go create mode 100644 relay/channel/ollama/model.go diff --git a/README.md b/README.md index 1cb30591..8f6c6bf7 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) + [x] [MINIMAX](https://api.minimax.chat/) + [x] [Groq](https://wow.groq.com/) + + [x] [Ollama](https://github.com/ollama/ollama) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 diff --git a/common/constants.go b/common/constants.go index de71bc7a..f4f575ba 100644 --- a/common/constants.go +++ b/common/constants.go @@ -69,6 +69,7 @@ const ( ChannelTypeMinimax ChannelTypeMistral ChannelTypeGroq + ChannelTypeOllama ChannelTypeDummy ) @@ -104,6 +105,7 @@ var ChannelBaseURLs = []string{ "https://api.minimax.chat", // 27 "https://api.mistral.ai", // 28 "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 } const ( diff --git a/common/helper/helper.go b/common/helper/helper.go index 76db5042..db41ac74 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -185,6 +185,10 @@ func GetTimeString() string { return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) } +func GenRequestID() string { + return GetTimeString() + GetRandomNumberString(8) +} + func Max(a int, b int) int { if a >= b { return a diff --git a/common/logger/logger.go b/common/logger/logger.go index ad0a0bea..957d8a11 100644 --- a/common/logger/logger.go +++ b/common/logger/logger.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" "io" "log" "os" @@ -94,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) { writer = gin.DefaultWriter } id := ctx.Value(RequestIdKey) + if id == nil { + id = helper.GenRequestID() + } now := time.Now() _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) if !setupLogWorking { diff --git a/middleware/request-id.go b/middleware/request-id.go index 234a93d8..a4c49ddb 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -9,7 +9,7 @@ import ( func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := helper.GetTimeString() + helper.GetRandomNumberString(8) + id := helper.GenRequestID() c.Set(logger.RequestIdKey, 
id) ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go new file mode 100644 index 00000000..06c66101 --- /dev/null +++ b/relay/channel/ollama/adaptor.go @@ -0,0 +1,65 @@ +package ollama + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/util" + "io" + "net/http" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(meta *util.RelayMeta) { + +} + +func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { + // https://github.com/ollama/ollama/blob/main/docs/api.md + fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) + return fullRequestURL, nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { + channel.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeEmbeddings: + return nil, errors.New("not supported") + default: + return ConvertRequest(*request), nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { + return channel.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + if meta.IsStream { + err, usage = StreamHandler(c, resp) + } else { + err, usage = Handler(c, resp) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "ollama" +} diff --git a/relay/channel/ollama/constants.go b/relay/channel/ollama/constants.go new file mode 100644 index 00000000..32f82b2a --- /dev/null +++ b/relay/channel/ollama/constants.go @@ -0,0 +1,5 @@ +package ollama + +var ModelList = []string{ + "qwen:0.5b-chat", +} diff --git a/relay/channel/ollama/main.go b/relay/channel/ollama/main.go new file mode 100644 index 00000000..7ec646a3 --- /dev/null +++ b/relay/channel/ollama/main.go @@ -0,0 +1,178 @@ +package ollama + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" +) + +func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { + ollamaRequest := ChatRequest{ + Model: request.Model, + Options: &Options{ + Seed: int(request.Seed), + Temperature: request.Temperature, + TopP: request.TopP, + FrequencyPenalty: request.FrequencyPenalty, + PresencePenalty: request.PresencePenalty, + }, + Stream: request.Stream, + } + for _, message := range request.Messages { + ollamaRequest.Messages = append(ollamaRequest.Messages, Message{ + Role: message.Role, + Content: message.StringContent(), + }) + } + return &ollamaRequest +} + +func 
responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { + choice := openai.TextResponseChoice{ + Index: 0, + Message: model.Message{ + Role: response.Message.Role, + Content: response.Message.Content, + }, + } + if response.Done { + choice.FinishReason = "stop" + } + fullTextResponse := openai.TextResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion", + Created: helper.GetTimestamp(), + Choices: []openai.TextResponseChoice{choice}, + Usage: model.Usage{ + PromptTokens: response.PromptEvalCount, + CompletionTokens: response.EvalCount, + TotalTokens: response.PromptEvalCount + response.EvalCount, + }, + } + return &fullTextResponse +} + +func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Role = ollamaResponse.Message.Role + choice.Delta.Content = ollamaResponse.Message.Content + if ollamaResponse.Done { + choice.FinishReason = &constant.StopFinishReason + } + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: ollamaResponse.Model, + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + return &response +} + +func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var usage model.Usage + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "}\n"); i >= 0 { + return i + 2, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := strings.TrimPrefix(scanner.Text(), "}") + dataChan <- data + "}" + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + var ollamaResponse ChatResponse + err := json.Unmarshal([]byte(data), &ollamaResponse) + if err != nil { + logger.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + if ollamaResponse.EvalCount != 0 { + usage.PromptTokens = ollamaResponse.PromptEvalCount + usage.CompletionTokens = ollamaResponse.EvalCount + usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount + } + response := streamResponseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, &usage +} + +func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + ctx := context.TODO() + var ollamaResponse ChatResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + logger.Debugf(ctx, "ollama response: %s", string(responseBody)) + err = resp.Body.Close() + if err != nil { + 
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &ollamaResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + if ollamaResponse.Error != "" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: ollamaResponse.Error, + Type: "ollama_error", + Param: "", + Code: "ollama_error", + }, + StatusCode: resp.StatusCode, + }, nil + } + fullTextResponse := responseOllama2OpenAI(&ollamaResponse) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &fullTextResponse.Usage +} diff --git a/relay/channel/ollama/model.go b/relay/channel/ollama/model.go new file mode 100644 index 00000000..a8ef1ffc --- /dev/null +++ b/relay/channel/ollama/model.go @@ -0,0 +1,37 @@ +package ollama + +type Options struct { + Seed int `json:"seed,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` + Images []string `json:"images,omitempty"` +} + +type ChatRequest struct { + Model string `json:"model,omitempty"` + Messages []Message `json:"messages,omitempty"` + Stream bool `json:"stream"` + Options *Options `json:"options,omitempty"` +} + +type ChatResponse struct { + Model string `json:"model,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + Message Message `json:"message,omitempty"` + Response string `json:"response,omitempty"` // for stream response + Done bool `json:"done,omitempty"` + TotalDuration int `json:"total_duration,omitempty"` + LoadDuration int `json:"load_duration,omitempty"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + EvalDuration int `json:"eval_duration,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index d2184dac..b249f6a2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -15,6 +15,7 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini + APITypeOllama APITypeDummy // this one is only for count, do not add any channel after this ) @@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeTencent case common.ChannelTypeGemini: apiType = APITypeGemini + case common.ChannelTypeOllama: + apiType = APITypeOllama } return apiType } diff --git a/relay/helper/main.go b/relay/helper/main.go index c2b6e6af..e7342329 100644 --- a/relay/helper/main.go +++ b/relay/helper/main.go @@ -7,6 +7,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/anthropic" "github.com/songquanpeng/one-api/relay/channel/baidu" "github.com/songquanpeng/one-api/relay/channel/gemini" + "github.com/songquanpeng/one-api/relay/channel/ollama" "github.com/songquanpeng/one-api/relay/channel/openai" "github.com/songquanpeng/one-api/relay/channel/palm" "github.com/songquanpeng/one-api/relay/channel/tencent" @@ -37,6 +38,8 @@ func GetAdaptor(apiType 
int) channel.Adaptor { return &xunfei.Adaptor{} case constant.APITypeZhipu: return &zhipu.Adaptor{} + case constant.APITypeOllama: + return &ollama.Adaptor{} } return nil } diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 8e9fc97c..c0379381 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -95,6 +95,12 @@ export const CHANNEL_OPTIONS = { value: 29, color: 'default' }, + 30: { + key: 30, + text: 'Ollama', + value: 30, + color: 'default' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index 897db189..c42c0253 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -166,6 +166,9 @@ const typeConfig = { 29: { modelGroup: "groq", }, + 30: { + modelGroup: "ollama", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index f6db46c3..c8284ef2 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -15,6 +15,7 @@ export const CHANNEL_OPTIONS = [ { key: 26, text: '百川大模型', value: 26, color: 'orange' }, { key: 27, text: 'MiniMax', value: 27, color: 'red' }, { key: 29, text: 'Groq', value: 29, color: 'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, From 89e111ac69f120c9213896430cf613e6cc201644 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 01:17:19 +0800 Subject: [PATCH 07/22] ci: fix ci condition --- .github/workflows/docker-image-amd64-en.yml | 1 + .github/workflows/docker-image-amd64.yml | 1 + .github/workflows/docker-image-arm64.yml | 1 + .github/workflows/linux-release.yml | 1 + .github/workflows/macos-release.yml | 1 + .github/workflows/windows-release.yml | 1 + 6 files changed, 6 insertions(+) diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index 44dc0bc0..fc0e8994 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -4,6 +4,7 @@ on: push: tags: - '*' + - '!*-pro*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index e3b8439a..983cd877 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -4,6 +4,7 @@ on: push: tags: - '*' + - '!*-pro*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index d6449eb8..d756830f 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -5,6 +5,7 @@ on: tags: - '*' - '!*-alpha*' + - '!*-pro*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index e81ab09f..b40bd629 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -7,6 +7,7 @@ on: tags: - '*' - '!*-alpha*' + - '!*-pro*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 13415276..166f15f8 100644 --- a/.github/workflows/macos-release.yml +++ 
b/.github/workflows/macos-release.yml @@ -7,6 +7,7 @@ on: tags: - '*' - '!*-alpha*' + - '!*-pro*' workflow_dispatch: inputs: name: diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 8b1160b4..02908d15 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -7,6 +7,7 @@ on: tags: - '*' - '!*-alpha*' + - '!*-pro*' workflow_dispatch: inputs: name: From be9eb59fbbcfb87d56f1e9fed7fbabd37c1d96bb Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:11:36 +0800 Subject: [PATCH 08/22] feat: support lingyiwanwu --- README.md | 1 + common/constants.go | 2 ++ common/model-ratio.go | 4 ++++ relay/channel/lingyiwanwu/constants.go | 9 +++++++++ relay/channel/openai/compatible.go | 4 ++++ web/berry/src/constants/ChannelConstants.js | 6 ++++++ web/berry/src/views/Channel/type/Config.js | 3 +++ web/default/src/constants/channel.constants.js | 1 + 8 files changed, 30 insertions(+) create mode 100644 relay/channel/lingyiwanwu/constants.go diff --git a/README.md b/README.md index 8f6c6bf7..0ba659c4 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 + [x] [MINIMAX](https://api.minimax.chat/) + [x] [Groq](https://wow.groq.com/) + [x] [Ollama](https://github.com/ollama/ollama) + + [x] [零一万物](https://platform.lingyiwanwu.com/) 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 3. 支持通过**负载均衡**的方式访问多个渠道。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 diff --git a/common/constants.go b/common/constants.go index f4f575ba..849bdce7 100644 --- a/common/constants.go +++ b/common/constants.go @@ -70,6 +70,7 @@ const ( ChannelTypeMistral ChannelTypeGroq ChannelTypeOllama + ChannelTypeLingYiWanWu ChannelTypeDummy ) @@ -106,6 +107,7 @@ var ChannelBaseURLs = []string{ "https://api.mistral.ai", // 28 "https://api.groq.com/openai", // 29 "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 } const ( diff --git a/common/model-ratio.go b/common/model-ratio.go index 5b0a759b..c1b2856d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -130,6 +130,10 @@ var ModelRatio = map[string]float64{ "llama2-7b-2048": 0.1 / 1000 * USD, "mixtral-8x7b-32768": 0.27 / 1000 * USD, "gemma-7b-it": 0.1 / 1000 * USD, + // https://platform.lingyiwanwu.com/docs#-计费单元 + "yi-34b-chat-0205": 2.5 / 1000000 * RMB, + "yi-34b-chat-200k": 12.0 / 1000000 * RMB, + "yi-vl-plus": 6.0 / 1000000 * RMB, } var CompletionRatio = map[string]float64{} diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/channel/lingyiwanwu/constants.go new file mode 100644 index 00000000..30000e9d --- /dev/null +++ b/relay/channel/lingyiwanwu/constants.go @@ -0,0 +1,9 @@ +package lingyiwanwu + +// https://platform.lingyiwanwu.com/docs + +var ModelList = []string{ + "yi-34b-chat-0205", + "yi-34b-chat-200k", + "yi-vl-plus", +} diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go index 767eec4b..e4951a34 100644 --- a/relay/channel/openai/compatible.go +++ b/relay/channel/openai/compatible.go @@ -5,6 +5,7 @@ import ( "github.com/songquanpeng/one-api/relay/channel/ai360" "github.com/songquanpeng/one-api/relay/channel/baichuan" "github.com/songquanpeng/one-api/relay/channel/groq" + "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" "github.com/songquanpeng/one-api/relay/channel/minimax" "github.com/songquanpeng/one-api/relay/channel/mistral" "github.com/songquanpeng/one-api/relay/channel/moonshot" @@ -18,6 +19,7 @@ var CompatibleChannels = []int{ 
common.ChannelTypeMinimax, common.ChannelTypeMistral, common.ChannelTypeGroq, + common.ChannelTypeLingYiWanWu, } func GetCompatibleChannelMeta(channelType int) (string, []string) { @@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { return "mistralai", mistral.ModelList case common.ChannelTypeGroq: return "groq", groq.ModelList + case common.ChannelTypeLingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList default: return "openai", ModelList } diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index c0379381..06597b93 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -101,6 +101,12 @@ export const CHANNEL_OPTIONS = { value: 30, color: 'default' }, + 31: { + key: 31, + text: '零一万物', + value: 31, + color: 'default' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js index c42c0253..8dfe77a4 100644 --- a/web/berry/src/views/Channel/type/Config.js +++ b/web/berry/src/views/Channel/type/Config.js @@ -169,6 +169,9 @@ const typeConfig = { 30: { modelGroup: "ollama", }, + 31: { + modelGroup: "lingyiwanwu", + }, }; export { defaultConfig, typeConfig }; diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index c8284ef2..8d536e58 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -16,6 +16,7 @@ export const CHANNEL_OPTIONS = [ { key: 27, text: 'MiniMax', value: 27, color: 'red' }, { key: 29, text: 'Groq', value: 29, color: 'orange' }, { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, From e3767cbb07b83fc28a41a5257a85cd34acfb209c Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:13:05 +0800 Subject: [PATCH 09/22] fix: fix haiku model name (close #1149) --- common/model-ratio.go | 2 +- relay/channel/anthropic/constants.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index c1b2856d..044f4f80 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{ "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD, - "claude-3-haiku-20240229": 0.25 / 1000 * USD, + "claude-3-haiku-20240307": 0.25 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 diff --git a/relay/channel/anthropic/constants.go b/relay/channel/anthropic/constants.go index fcc0c2a3..cadcedc8 100644 --- a/relay/channel/anthropic/constants.go +++ b/relay/channel/anthropic/constants.go @@ -2,7 +2,7 @@ package anthropic var ModelList = []string{ "claude-instant-1.2", "claude-2.0", "claude-2.1", - "claude-3-haiku-20240229", + "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229", } From c28ec1079504529fdc506560de8c74b18b15907a Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:14:39 +0800 Subject: [PATCH 10/22] fix: fix cors for dashboard api --- router/dashboard.go | 1 + 1 file changed, 1 insertion(+) diff --git a/router/dashboard.go 
b/router/dashboard.go index 0b539d44..5952d698 100644 --- a/router/dashboard.go +++ b/router/dashboard.go @@ -9,6 +9,7 @@ import ( func SetDashboardRouter(router *gin.Engine) { apiRouter := router.Group("/") + apiRouter.Use(middleware.CORS()) apiRouter.Use(gzip.Gzip(gzip.DefaultCompression)) apiRouter.Use(middleware.GlobalAPIRateLimit()) apiRouter.Use(middleware.TokenAuth()) From f33555ae78b96c86834e5c77e0bc7098fc6bae81 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:17:19 +0800 Subject: [PATCH 11/22] fix: update max token for test (close #1154) --- controller/channel-test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 6d18305a..67ac91d0 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -30,7 +30,7 @@ import ( func buildTestRequest() *relaymodel.GeneralOpenAIRequest { testRequest := &relaymodel.GeneralOpenAIRequest{ - MaxTokens: 1, + MaxTokens: 2, Stream: false, Model: "gpt-3.5-turbo", } From b16917386043bbce6db52ea3bf4f2c8ff57caf0d Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:20:38 +0800 Subject: [PATCH 12/22] fix: force set Accept header for ali stream request (close #1151) --- relay/channel/ali/adaptor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 6c6f433e..6a3245ad 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -32,6 +32,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { channel.SetupCommonRequestHeader(c, req, meta) + if meta.IsStream { + req.Header.Set("Accept", "text/event-stream") + } req.Header.Set("Authorization", "Bearer "+meta.APIKey) if meta.IsStream { req.Header.Set("X-DashScope-SSE", "enable") From 8ede66a896053fae4b233fb97701ec3f4919c456 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:27:47 +0800 Subject: [PATCH 13/22] fix: fix ci --- .github/workflows/docker-image-amd64-en.yml | 8 +++++++- .github/workflows/docker-image-amd64.yml | 8 +++++++- .github/workflows/docker-image-arm64.yml | 8 +++++++- .github/workflows/linux-release.yml | 7 ++++++- .github/workflows/macos-release.yml | 7 ++++++- .github/workflows/windows-release.yml | 7 ++++++- 6 files changed, 39 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index fc0e8994..af488256 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -4,7 +4,6 @@ on: push: tags: - '*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -21,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index 983cd877..2079d31f 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -4,7 +4,6 @@ on: push: tags: - '*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -21,6 +20,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; 
then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index d756830f..39d1a401 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -5,7 +5,6 @@ on: tags: - '*' - '!*-alpha*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -22,6 +21,13 @@ jobs: - name: Check out the repo uses: actions/checkout@v3 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi + - name: Save version info run: | git describe --tags > VERSION diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index b40bd629..6f30a1d5 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -7,7 +7,6 @@ on: tags: - '*' - '!*-alpha*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -21,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 166f15f8..359c2c92 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -7,7 +7,6 @@ on: tags: - '*' - '!*-alpha*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -21,6 +20,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 02908d15..4e99b75c 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -7,7 +7,6 @@ on: tags: - '*' - '!*-alpha*' - - '!*-pro*' workflow_dispatch: inputs: name: @@ -24,6 +23,12 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Check repository URL + run: | + REPO_URL=$(git config --get remote.origin.url) + if [[ $REPO_URL == *"pro" ]]; then + exit 1 + fi - uses: actions/setup-node@v3 with: node-version: 16 From 66efabd5ae540d60e91e6bf3af89b7a9bad06231 Mon Sep 17 00:00:00 2001 From: Jguobao <779188083@qq.com> Date: Thu, 14 Mar 2024 23:31:07 +0800 Subject: [PATCH 14/22] fix: fix baidu url check (#1143) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加百度的另外3个向量模型【"bge-large-zh", "bge-large-en", "tao-8k", 】 --- relay/channel/baidu/adaptor.go | 21 +++++++++++++++++---- relay/channel/baidu/constants.go | 3 +++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 1a96997a..2d2e24f6 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -3,14 +3,15 @@ package baidu import ( "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/relay/channel" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" - "strings" ) type Adaptor struct { @@ -23,7 +24,13 @@ func (a *Adaptor) Init(meta *util.RelayMeta) { func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { 
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t suffix := "chat/" - if strings.HasPrefix("Embedding", meta.ActualModelName) { + if strings.HasPrefix(meta.ActualModelName, "Embedding") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "bge-large") { + suffix = "embeddings/" + } + if strings.HasPrefix(meta.ActualModelName, "tao-8k") { suffix = "embeddings/" } switch meta.ActualModelName { @@ -45,6 +52,12 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { suffix += "bloomz_7b1" case "Embedding-V1": suffix += "embedding-v1" + case "bge-large-zh": + suffix += "bge_large_zh" + case "bge-large-en": + suffix += "bge_large_en" + case "tao-8k": + suffix += "tao_8k" default: suffix += meta.ActualModelName } diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index 0fa8f2d6..45a4e901 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -7,4 +7,7 @@ var ModelList = []string{ "ERNIE-Speed", "ERNIE-Bot-turbo", "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", } From 7cd57f3125ce60ed0d61bed1edf86409b1be7906 Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:36:10 +0800 Subject: [PATCH 15/22] chore: update ratio for baidu embedding --- common/model-ratio.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/common/model-ratio.go b/common/model-ratio.go index 044f4f80..4a1a0013 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{ "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens "ERNIE-Bot-8k": 0.024 * RMB, "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "bge-large-8k": 0.002 * RMB, "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens From 0926b6206bd1c913ac1c7fbc24dba33d385b0426 Mon Sep 17 00:00:00 2001 From: afafw <82748932+afafw@users.noreply.github.com> Date: Thu, 14 Mar 2024 23:44:46 +0800 Subject: [PATCH 16/22] chore: update client name (#934) --- web/berry/src/views/Token/component/TableRow.js | 2 +- web/default/src/components/TokensTable.js | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/web/berry/src/views/Token/component/TableRow.js b/web/berry/src/views/Token/component/TableRow.js index 19594b4c..2753764c 100644 --- a/web/berry/src/views/Token/component/TableRow.js +++ b/web/berry/src/views/Token/component/TableRow.js @@ -31,7 +31,7 @@ const COPY_OPTIONS = [ url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}', encode: false }, - { key: 'ama', text: 'AMA 问天', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true }, + { key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true }, { key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true } ]; diff --git a/web/default/src/components/TokensTable.js b/web/default/src/components/TokensTable.js index db4745e4..d6ad2a21 100644 --- a/web/default/src/components/TokensTable.js +++ b/web/default/src/components/TokensTable.js @@ -8,12 +8,12 @@ import { renderQuota } from '../helpers/render'; const COPY_OPTIONS = [ { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, - { key: 'ama', text: 'AMA 问天', value: 'ama' }, + { key: 'ama', text: 'BotGem', value: 'ama' }, { key: 'opencat', text: 'OpenCat', value: 
'opencat' }, ]; const OPEN_LINK_OPTIONS = [ - { key: 'ama', text: 'AMA 问天', value: 'ama' }, + { key: 'ama', text: 'BotGem', value: 'ama' }, { key: 'opencat', text: 'OpenCat', value: 'opencat' }, ]; From 3edf7247c49b788d8d268d22d30730e299fb0d09 Mon Sep 17 00:00:00 2001 From: "E.da" <46555117+yooyui@users.noreply.github.com> Date: Thu, 14 Mar 2024 23:45:50 +0800 Subject: [PATCH 17/22] fix: fix theme berry copy (#1148) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 调整berry主题页脚`label`表述 --- web/berry/src/views/Setting/component/OtherSetting.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/berry/src/views/Setting/component/OtherSetting.js b/web/berry/src/views/Setting/component/OtherSetting.js index 01f92f77..426b8c81 100644 --- a/web/berry/src/views/Setting/component/OtherSetting.js +++ b/web/berry/src/views/Setting/component/OtherSetting.js @@ -265,7 +265,7 @@ const OtherSetting = () => { multiline maxRows={15} id="Footer" - label="公告" + label="页脚" value={inputs.Footer} name="Footer" onChange={handleInputChange} From 3e2e805d61981fb73c8e03e661e3c5865bbb34d4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Mar 2024 23:46:17 +0800 Subject: [PATCH 18/22] chore(deps): bump google.golang.org/protobuf from 1.30.0 to 1.33.0 (#1145) Bumps google.golang.org/protobuf from 1.30.0 to 1.33.0. --- updated-dependencies: - dependency-name: google.golang.org/protobuf dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 4ab23003..8d255cf0 100644 --- a/go.mod +++ b/go.mod @@ -60,6 +60,6 @@ require ( golang.org/x/net v0.17.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect - google.golang.org/protobuf v1.30.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 21bcddc6..e8d0aad6 100644 --- a/go.sum +++ b/go.sum @@ -177,8 +177,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IV golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= From ebfee3b46cb4e6f9b36a18cfbc7e85e3a96cb913 Mon Sep 17 00:00:00 2001 From: warjiang <1096409085@qq.com> Date: Thu, 14 Mar 2024 23:47:46 +0800 Subject: [PATCH 19/22] feat: add support for private registry in docker-compose.yml (#1103) --- docker-compose.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 
deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index 30edb281..1325a818 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.4' services: one-api: - image: justsong/one-api:latest + image: "${REGISTRY:-docker.io}/justsong/one-api:latest" container_name: one-api restart: always command: --log-dir /app/logs @@ -29,12 +29,12 @@ services: retries: 3 redis: - image: redis:latest + image: "${REGISTRY:-docker.io}/redis:latest" container_name: redis restart: always db: - image: mysql:8.2.0 + image: "${REGISTRY:-docker.io}/mysql:8.2.0" restart: always container_name: mysql volumes: From 996f4d99dded958ecc072c30a241e73c22a4c9fb Mon Sep 17 00:00:00 2001 From: JustSong Date: Thu, 14 Mar 2024 23:53:25 +0800 Subject: [PATCH 20/22] ci: fix ci --- .github/workflows/docker-image-amd64-en.yml | 2 +- .github/workflows/docker-image-amd64.yml | 2 +- .github/workflows/docker-image-arm64.yml | 2 +- .github/workflows/linux-release.yml | 2 +- .github/workflows/macos-release.yml | 2 +- .github/workflows/windows-release.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/docker-image-amd64-en.yml b/.github/workflows/docker-image-amd64-en.yml index af488256..92c8a67f 100644 --- a/.github/workflows/docker-image-amd64-en.yml +++ b/.github/workflows/docker-image-amd64-en.yml @@ -24,7 +24,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - name: Save version info diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index 2079d31f..9fd20f46 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -24,7 +24,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - name: Save version info diff --git a/.github/workflows/docker-image-arm64.yml b/.github/workflows/docker-image-arm64.yml index 39d1a401..24b4a4b7 100644 --- a/.github/workflows/docker-image-arm64.yml +++ b/.github/workflows/docker-image-arm64.yml @@ -25,7 +25,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - name: Save version info diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml index 6f30a1d5..4b5694e7 100644 --- a/.github/workflows/linux-release.yml +++ b/.github/workflows/linux-release.yml @@ -24,7 +24,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - uses: actions/setup-node@v3 with: diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml index 359c2c92..8304de05 100644 --- a/.github/workflows/macos-release.yml +++ b/.github/workflows/macos-release.yml @@ -24,7 +24,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - uses: actions/setup-node@v3 with: diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml index 4e99b75c..eb1cbe21 100644 --- a/.github/workflows/windows-release.yml +++ b/.github/workflows/windows-release.yml @@ -27,7 +27,7 @@ jobs: run: | REPO_URL=$(git config --get remote.origin.url) if [[ $REPO_URL == *"pro" ]]; then - exit 1 + exit 0 fi - uses: actions/setup-node@v3 with: From 752639560fd2a6a5c655f66e3ed6a04159faf8db Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 15 Mar 2024 00:30:15 +0800 Subject: [PATCH 21/22] 
feat: able to use separated database for table logs --- README.md | 39 +++++++++++++++++++------------------ main.go | 16 +++++++++++++++- model/log.go | 24 +++++++++++------------ model/main.go | 53 ++++++++++++++++++++++++++++++--------------------- 4 files changed, 78 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 0ba659c4..d0b6b70c 100644 --- a/README.md +++ b/README.md @@ -349,38 +349,39 @@ graph LR + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 -4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 +4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 +5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 +6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`MEMORY_CACHE_ENABLED=true` -6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 +7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + 例子:`SYNC_FREQUENCY=60` -7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 +8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` -8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 +9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 - + 例子:`CHANNEL_TEST_FREQUENCY=1440` -10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 +10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 + + 例子:`CHANNEL_TEST_FREQUENCY=1440` +11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` -11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 +12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 +13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` -13. 请求频率限制: +14. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 -14. 编码器缓存设置: +15. 编码器缓存设置: + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 -15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 -19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 -20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 -21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +17. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +18. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 ### 命令行参数 1. 
`--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/main.go b/main.go index 83d7e7ed..b20c6daf 100644 --- a/main.go +++ b/main.go @@ -30,11 +30,25 @@ func main() { if config.DebugEnabled { logger.SysLog("running in debug mode") } + var err error // Initialize SQL Database - err := model.InitDB() + model.DB, err = model.InitDB("SQL_DSN") if err != nil { logger.FatalLog("failed to initialize database: " + err.Error()) } + if os.Getenv("LOG_SQL_DSN") != "" { + logger.SysLog("using secondary database for table logs") + model.LOG_DB, err = model.InitDB("LOG_SQL_DSN") + if err != nil { + logger.FatalLog("failed to initialize secondary database: " + err.Error()) + } + } else { + model.LOG_DB = model.DB + } + err = model.CreateRootAccountIfNeed() + if err != nil { + logger.FatalLog("database init error: " + err.Error()) + } defer func() { err := model.CloseDB() if err != nil { diff --git a/model/log.go b/model/log.go index 85c0ba90..4409f73e 100644 --- a/model/log.go +++ b/model/log.go @@ -45,7 +45,7 @@ func RecordLog(userId int, logType int, content string) { Type: logType, Content: content, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.SysError("failed to record log: " + err.Error()) } @@ -69,7 +69,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke Quota: int(quota), ChannelId: channelId, } - err := DB.Create(log).Error + err := LOG_DB.Create(log).Error if err != nil { logger.Error(ctx, "failed to record log: "+err.Error()) } @@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB + tx = LOG_DB } else { - tx = DB.Where("type = ?", logType) + tx = LOG_DB.Where("type = ?", logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { - tx = DB.Where("user_id = ?", userId) + tx = LOG_DB.Where("user_id = ?", userId) } else { - tx = DB.Where("user_id = ? and type = ?", userId, logType) + tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { tx = tx.Where("model_name = ?", modelName) @@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error + err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error + err = LOG_DB.Where("user_id = ? 
and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { - tx := DB.Table("logs").Select("ifnull(sum(quota),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { - tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") + tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") if username != "" { tx = tx.Where("username = ?", username) } @@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa } func DeleteOldLog(targetTimestamp int64) (int64, error) { - result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) + result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{}) return result.RowsAffected, result.Error } @@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day" } - err = DB.Raw(` + err = LOG_DB.Raw(` SELECT `+groupSelect+`, model_name, count(1) as request_count, sum(quota) as quota, diff --git a/model/main.go b/model/main.go index 05150fd9..ca7a35b2 100644 --- a/model/main.go +++ b/model/main.go @@ -17,8 +17,9 @@ import ( ) var DB *gorm.DB +var LOG_DB *gorm.DB -func createRootAccountIfNeed() error { +func CreateRootAccountIfNeed() error { var user User //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { @@ -41,9 +42,9 @@ func createRootAccountIfNeed() error { return nil } -func chooseDB() (*gorm.DB, error) { - if os.Getenv("SQL_DSN") != "" { - dsn := os.Getenv("SQL_DSN") +func chooseDB(envName string) (*gorm.DB, error) { + if os.Getenv(envName) != "" { + dsn := os.Getenv(envName) if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL logger.SysLog("using PostgreSQL as database") @@ -71,23 +72,22 @@ func chooseDB() (*gorm.DB, error) { }) } -func InitDB() (err error) { - db, err := chooseDB() +func InitDB(envName string) (db *gorm.DB, err error) { + db, err = chooseDB(envName) if err == nil { if config.DebugSQLEnabled { db = db.Debug() } - DB = db - sqlDB, err := DB.DB() + sqlDB, err := db.DB() if err != nil { - return err + return nil, err } sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100)) sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000)) sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60))) if !config.IsMasterNode { - return nil + return db, err } if common.UsingMySQL { _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded @@ -95,46 +95,55 @@ func InitDB() (err error) { logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Token{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&User{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Option{}) if err != nil { - return err + return nil, 
err } err = db.AutoMigrate(&Redemption{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Ability{}) if err != nil { - return err + return nil, err } err = db.AutoMigrate(&Log{}) if err != nil { - return err + return nil, err } logger.SysLog("database migrated") - err = createRootAccountIfNeed() - return err + return db, err } else { logger.FatalLog(err) } - return err + return db, err } -func CloseDB() error { - sqlDB, err := DB.DB() +func closeDB(db *gorm.DB) error { + sqlDB, err := db.DB() if err != nil { return err } err = sqlDB.Close() return err } + +func CloseDB() error { + if LOG_DB != DB { + err := closeDB(LOG_DB) + if err != nil { + return err + } + } + return closeDB(DB) +} From b204f6d82b73c382e0debdb4f43872b906e07f45 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 15 Mar 2024 00:55:28 +0800 Subject: [PATCH 22/22] docs: update README --- README.md | 39 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index d0b6b70c..0ba659c4 100644 --- a/README.md +++ b/README.md @@ -349,39 +349,38 @@ graph LR + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。 + 如果报错 `Error 1040: Too many connections`,请适当减小该值。 + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。 -4. `LOG_SQL_DSN`:设置之后将为 `logs` 表使用独立的数据库,请使用 MySQL 或 PostgreSQL。 -5. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 +4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` -6. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 +5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`MEMORY_CACHE_ENABLED=true` -7. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 +6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。 + 例子:`SYNC_FREQUENCY=60` -8. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 +7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 + 例子:`NODE_TYPE=slave` -9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 +8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 - + 例子:`CHANNEL_TEST_FREQUENCY=1440` -11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 +9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 + + 例子:`CHANNEL_TEST_FREQUENCY=1440` +10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` -12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 +11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。 + 例子:`BATCH_UPDATE_ENABLED=true` + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。 -13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 +12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。 + 例子:`BATCH_UPDATE_INTERVAL=5` -14. 请求频率限制: +13. 请求频率限制: + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。 + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。 -15. 编码器缓存设置: +14. 编码器缓存设置: + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。 + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。 -16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 -17. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 -18. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 -19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 -20. 
`ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 -21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 -22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 +15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。 +16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。 +17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。 +18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。 +19. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。 +20. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。 +21. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 ### 命令行参数 1. `--port <port_num>`: 指定服务器监听的端口号,默认为 `3000`。
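A minimal usage sketch for the separated log database introduced in [PATCH 21/22] (not part of the patches themselves; the hosts, credentials and binary path below are placeholder assumptions, not values from the series):

    # Assumed example values: point LOG_SQL_DSN at a MySQL or PostgreSQL instance,
    # as the README text above requires; SQL_DSN keeps serving the main tables.
    export SQL_DSN="oneapi:secret@tcp(db:3306)/oneapi"
    export LOG_SQL_DSN="oneapi:secret@tcp(logdb:3306)/oneapi_logs"
    ./one-api --port 3000 --log-dir /app/logs

When LOG_SQL_DSN is unset, main.go assigns model.LOG_DB = model.DB, so existing single-database deployments behave exactly as before.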