Merge branch 'songquanpeng' into sync_upstream

Martial BE
2023-12-21 15:36:01 +08:00
41 changed files with 1045 additions and 89 deletions

View File

@@ -68,11 +68,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
 	}
 	promptTokens := common.CountTokenMessages(request.Messages, request.Model)
-	_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
+	Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
 	if openAIErrorWithStatusCode != nil {
 		return nil, &openAIErrorWithStatusCode.OpenAIError
 	}
+	if Usage.CompletionTokens == 0 {
+		return errors.New(fmt.Sprintf("channel %s, message 补全 tokens 非预期返回 0", channel.Name)), nil
+	}
 	return nil, nil
 }
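
The probe now inspects the returned Usage, so a request that comes back HTTP-OK but with zero completion tokens counts as a failed test rather than a pass. One stylistic aside: errors.New(fmt.Sprintf(...)) is more idiomatically written with fmt.Errorf, which is equivalent here:

// Equivalent, more idiomatic form of the added error return (sketch):
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil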
@@ -134,20 +138,32 @@ func TestChannel(c *gin.Context) {
 var testAllChannelsLock sync.Mutex
 var testAllChannelsRunning bool = false
 
-// disable & notify
-func disableChannel(channelId int, channelName string, reason string) {
+func notifyRootUser(subject string, content string) {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
 	}
-	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
-	subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
-	content := fmt.Sprintf("通道「%s」#%d已被禁用,原因:%s", channelName, channelId, reason)
 	err := common.SendEmail(subject, common.RootUserEmail, content)
 	if err != nil {
 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 	}
 }
+
+// disable & notify
+func disableChannel(channelId int, channelName string, reason string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+	subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」#%d已被禁用,原因:%s", channelName, channelId, reason)
+	notifyRootUser(subject, content)
+}
+
+// enable & notify
+func enableChannel(channelId int, channelName string) {
+	model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+	subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
+	content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
+	notifyRootUser(subject, content)
+}
+
 func testAllChannels(notify bool) error {
 	if common.RootUserEmail == "" {
 		common.RootUserEmail = model.GetRootUserEmail()
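
The email plumbing that used to live inline in disableChannel is factored into notifyRootUser, which lazily resolves the root address, so enableChannel can reuse it. In enableChannel the subject and content strings are identical, so a slightly tighter form would build the string once (same behavior, not part of the commit):

// enableChannel builds the same string twice; one Sprintf would do (sketch):
msg := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(msg, msg)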
@@ -170,9 +186,7 @@ func testAllChannels(notify bool) error {
 	}
 	go func() {
 		for _, channel := range channels {
-			if channel.Status != common.ChannelStatusEnabled {
-				continue
-			}
+			isChannelEnabled := channel.Status == common.ChannelStatusEnabled
 			tik := time.Now()
 			err, openaiErr := testChannel(channel, *testRequest)
 			tok := time.Now()
@@ -181,9 +195,12 @@ func testAllChannels(notify bool) error {
 				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
-			if shouldDisableChannel(openaiErr, -1) {
+			if isChannelEnabled && shouldDisableChannel(openaiErr, -1) {
 				disableChannel(channel.Id, channel.Name, err.Error())
 			}
+			if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
+				enableChannel(channel.Id, channel.Name)
+			}
 			channel.UpdateResponseTime(milliseconds)
 			time.Sleep(common.RequestInterval)
 		}
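
Disabled channels are no longer skipped by the test loop: each channel's prior status is captured in isChannelEnabled before the probe, enabled channels can be auto-disabled on failure, and disabled channels can be auto-re-enabled on success. The two guards are mutually exclusive on isChannelEnabled, so the decision reads as a single branch (illustrative rewrite, same logic as the two ifs above):

// Per-channel state machine implied by the loop (sketch):
switch {
case isChannelEnabled && shouldDisableChannel(openaiErr, -1):
	disableChannel(channel.Id, channel.Name, err.Error())
case !isChannelEnabled && shouldEnableChannel(err, openaiErr):
	enableChannel(channel.Id, channel.Name)
}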

View File

@@ -361,6 +361,24 @@ func init() {
 		Root:       "claude-2",
 		Parent:     nil,
 	},
+	{
+		Id:         "claude-2.1",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "anthropic",
+		Permission: permission,
+		Root:       "claude-2.1",
+		Parent:     nil,
+	},
+	{
+		Id:         "claude-2.0",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "anthropic",
+		Permission: permission,
+		Root:       "claude-2.0",
+		Parent:     nil,
+	},
 	{
 		Id:         "ERNIE-Bot",
 		Object:     "model",
@@ -406,6 +424,15 @@ func init() {
 		Root:       "PaLM-2",
 		Parent:     nil,
 	},
+	{
+		Id:         "gemini-pro",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "google",
+		Permission: permission,
+		Root:       "gemini-pro",
+		Parent:     nil,
+	},
 	{
 		Id:         "chatglm_turbo",
 		Object:     "model",
@@ -460,6 +487,24 @@ func init() {
 		Root:       "qwen-plus",
 		Parent:     nil,
 	},
+	{
+		Id:         "qwen-max",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "ali",
+		Permission: permission,
+		Root:       "qwen-max",
+		Parent:     nil,
+	},
+	{
+		Id:         "qwen-max-longcontext",
+		Object:     "model",
+		Created:    1677649963,
+		OwnedBy:    "ali",
+		Permission: permission,
+		Root:       "qwen-max-longcontext",
+		Parent:     nil,
+	},
 	{
 		Id:         "text-embedding-v1",
 		Object:     "model",
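
These registry entries are what OpenAI-compatible clients discover through the models list endpoint. A minimal sketch of serving them, assuming the entries above are collected in a slice (named openAIModels here for illustration; neither the slice name nor the handler appears in this diff) and the project's gin router:

// Hypothetical list handler in OpenAI's list format; names are assumptions.
func ListModels(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   openAIModels,
	})
}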

View File

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"context"
+	"math"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -24,6 +25,11 @@ func RelayChat(c *gin.Context) {
 		return
 	}
 
+	if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
+		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
+		return
+	}
+
 	// 解析模型映射
 	var isModelMapped bool
 	modelMap, err := parseModelMapping(channel.GetModelMapping())
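
The new guard rejects max_tokens values that are negative or implausibly large before any provider is called; RelayCompletions below adds the identical check. Expressed as a standalone predicate (hypothetical helper name; the upper bound math.MaxInt32/2 presumably leaves headroom so later arithmetic on the value cannot overflow a 32-bit int):

// Sketch of the guard as a predicate:
func validMaxTokens(n int) bool {
	return n >= 0 && n <= math.MaxInt32/2
}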

View File

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"context"
+	"math"
 	"net/http"
 	"one-api/common"
 	"one-api/model"
@@ -24,6 +25,11 @@ func RelayCompletions(c *gin.Context) {
 		return
 	}
 
+	if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
+		common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
+		return
+	}
+
 	// 解析模型映射
 	var isModelMapped bool
 	modelMap, err := parseModelMapping(channel.GetModelMapping())

View File

@@ -24,6 +24,10 @@ func RelayImageGenerations(c *gin.Context) {
 		imageRequest.Model = "dall-e-2"
 	}
 
+	if imageRequest.N == 0 {
+		imageRequest.N = 1
+	}
+
 	if imageRequest.Size == "" {
 		imageRequest.Size = "1024x1024"
 	}
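
The added default mirrors OpenAI's behavior of treating an omitted n as 1, alongside the existing model and size defaults. Gathered into one normalization pass for readability (hypothetical helper and request type name; field names follow the diff, and the handler actually applies these inline):

// Sketch only, under the assumptions above:
func applyImageDefaults(r *types.ImageRequest) {
	if r.Model == "" {
		r.Model = "dall-e-2"
	}
	if r.N == 0 {
		r.N = 1 // OpenAI's documented default
	}
	if r.Size == "" {
		r.Size = "1024x1024"
	}
}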

View File

@@ -110,9 +110,14 @@ func setChannelToContext(c *gin.Context, channel *model.Channel) {
 		c.Set("api_version", channel.Other)
 	case common.ChannelTypeXunfei:
 		c.Set("api_version", channel.Other)
+	case common.ChannelTypeGemini:
+		c.Set("api_version", channel.Other)
 	case common.ChannelTypeAIProxyLibrary:
 		c.Set("library_id", channel.Other)
+	case common.ChannelTypeAli:
+		c.Set("plugin", channel.Other)
 	}
 }
 
 func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
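
Gemini channels reuse the generic api_version slot, while Ali channels stash a plugin identifier; both ride on gin's per-request context. The read side in a provider would look roughly like this (illustrative; the real provider code may differ):

// Sketch of reading a stashed value back out of the gin context:
if v, exists := c.Get("plugin"); exists {
	plugin, _ := v.(string)
	_ = plugin // e.g. select the Ali plugin configured for this channel
}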
@@ -131,8 +136,22 @@ func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
 	return false
 }
 
-func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
-	err := model.PostConsumeTokenQuota(tokenId, quota)
+func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
+	if !common.AutomaticEnableChannelEnabled {
+		return false
+	}
+	if err != nil {
+		return false
+	}
+	if openAIErr != nil {
+		return false
+	}
+	return true
+}
+
+func postConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+	// quotaDelta is remaining quota to be consumed
+	err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 	if err != nil {
 		common.SysError("error consuming token remain quota: " + err.Error())
 	}
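
shouldEnableChannel gates automatic re-enabling on three conditions: the global AutomaticEnableChannelEnabled switch, no transport-level error, and no API-level error. The early-return chain is equivalent to a single conjunction (sketch):

// Equivalent one-expression form of shouldEnableChannel:
return common.AutomaticEnableChannelEnabled && err == nil && openAIErr == nil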
@@ -140,11 +159,15 @@ func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, c
 	if err != nil {
 		common.SysError("error update user quota cache: " + err.Error())
 	}
-	if quota != 0 {
+	// totalQuota is total quota consumed
+	if totalQuota != 0 {
 		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-		model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
-		model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
-		model.UpdateChannelUsedQuota(channelId, quota)
+		model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
+		model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
+		model.UpdateChannelUsedQuota(channelId, totalQuota)
 	}
+	if totalQuota <= 0 {
+		common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+	}
 }
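
postConsumeQuota now separates quotaDelta, the remainder still to settle against the token after any pre-consumption, from totalQuota, the full amount the request consumed, which feeds the consume log and the per-user and per-channel counters; the final check flags a non-positive totalQuota as an accounting anomaly. The assumed call-site relationship (preConsumedQuota is the amount reserved before the request ran and is not shown in this diff):

// Sketch under the assumption above:
quotaDelta := totalQuota - preConsumedQuota
postConsumeQuota(ctx, tokenId, quotaDelta, totalQuota, userId, channelId,
	modelRatio, groupRatio, modelName, tokenName)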