mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-15 04:33:42 +08:00
Merge branch 'songquanpeng' into sync_upstream
This commit is contained in:
@@ -68,11 +68,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
}
|
||||
|
||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||
}
|
||||
|
||||
if Usage.CompletionTokens == 0 {
|
||||
return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -134,20 +138,32 @@ func TestChannel(c *gin.Context) {
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
func notifyRootUser(subject string, content string) {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
// enable & notify
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
@@ -170,9 +186,7 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
go func() {
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
tok := time.Now()
|
||||
@@ -181,9 +195,12 @@ func testAllChannels(notify bool) error {
|
||||
err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if shouldDisableChannel(openaiErr, -1) {
|
||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
||||
@@ -361,6 +361,24 @@ func init() {
|
||||
Root: "claude-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.0",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.0",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot",
|
||||
Object: "model",
|
||||
@@ -406,6 +424,15 @@ func init() {
|
||||
Root: "PaLM-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
@@ -460,6 +487,24 @@ func init() {
|
||||
Root: "qwen-plus",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max-longcontext",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max-longcontext",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-v1",
|
||||
Object: "model",
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -24,6 +25,11 @@ func RelayChat(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -24,6 +25,11 @@ func RelayCompletions(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
|
||||
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析模型映射
|
||||
var isModelMapped bool
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
|
||||
@@ -24,6 +24,10 @@ func RelayImageGenerations(c *gin.Context) {
|
||||
imageRequest.Model = "dall-e-2"
|
||||
}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
if imageRequest.Size == "" {
|
||||
imageRequest.Size = "1024x1024"
|
||||
}
|
||||
|
||||
@@ -110,9 +110,14 @@ func setChannelToContext(c *gin.Context, channel *model.Channel) {
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||
@@ -131,8 +136,22 @@ func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
@@ -140,11 +159,15 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
if totalQuota <= 0 {
|
||||
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user