mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-17 19:13:42 +08:00
Compare commits
11 Commits
v0.2.1.0-a
...
v0.2.1.0-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f188147680 | ||
|
|
0a49715c3d | ||
|
|
89efed48fc | ||
|
|
97e0aae0a7 | ||
|
|
320da09f36 | ||
|
|
2d849e0dd6 | ||
|
|
60d7ed3fb5 | ||
|
|
c5f6d0e063 | ||
|
|
a7cfce24d0 | ||
|
|
34bf8f8945 | ||
|
|
5961de03e7 |
@@ -201,16 +201,16 @@ func GetCompletionRatio(name string) float64 {
|
||||
return 4.0 / 3.0
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") {
|
||||
if strings.HasSuffix(name, "preview") {
|
||||
if strings.HasSuffix(name, "gpt-4-turbo") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
if strings.Contains(name, "claude-instant-1") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-2") {
|
||||
} else if strings.Contains(name, "claude-2") {
|
||||
return 3
|
||||
} else if strings.HasPrefix(name, "claude-3") {
|
||||
} else if strings.Contains(name, "claude-3") {
|
||||
return 5
|
||||
}
|
||||
if strings.HasPrefix(name, "mistral-") {
|
||||
|
||||
@@ -2,6 +2,8 @@ package constant
|
||||
|
||||
var MjNotifyEnabled = false
|
||||
|
||||
var MjModeClearEnabled = false
|
||||
|
||||
const (
|
||||
MjErrorUnknown = 5
|
||||
MjRequestError = 4
|
||||
|
||||
@@ -92,6 +92,9 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode == 307 {
|
||||
return true
|
||||
}
|
||||
if openaiErr.StatusCode/100 == 5 {
|
||||
// 超时不重试
|
||||
if openaiErr.StatusCode == 504 || openaiErr.StatusCode == 524 {
|
||||
|
||||
@@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) {
|
||||
userId := c.GetInt("id")
|
||||
var channel *model.Channel
|
||||
channelId, ok := c.Get("specific_channel_id")
|
||||
modelRequest, shouldSelectChannel, err := getModelRequest(c)
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
@@ -40,72 +43,7 @@ func Distribute() func(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
shouldSelectChannel := true
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
// check token model mapping
|
||||
modelLimitEnable := c.GetBool("token_model_limit_enabled")
|
||||
if modelLimitEnable {
|
||||
@@ -128,8 +66,6 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
userGroup, _ := model.CacheGetUserGroup(userId)
|
||||
c.Set("group", userGroup)
|
||||
if shouldSelectChannel {
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
|
||||
if err != nil {
|
||||
@@ -147,14 +83,87 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
|
||||
return
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
}
|
||||
}
|
||||
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
||||
var modelRequest ModelRequest
|
||||
shouldSelectChannel := true
|
||||
var err error
|
||||
if strings.Contains(c.Request.URL.Path, "/mj/") {
|
||||
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
|
||||
if relayMode == relayconstant.RelayModeMidjourneyTaskFetch ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskFetchByCondition ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyNotify ||
|
||||
relayMode == relayconstant.RelayModeMidjourneyTaskImageSeed {
|
||||
shouldSelectChannel = false
|
||||
} else {
|
||||
midjourneyRequest := dto.MidjourneyRequest{}
|
||||
err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
|
||||
if err != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
|
||||
if mjErr != nil {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
|
||||
return nil, false, fmt.Errorf(mjErr.Description)
|
||||
}
|
||||
if midjourneyModel == "" {
|
||||
if !success {
|
||||
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
|
||||
return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
|
||||
} else {
|
||||
// task fetch, task fetch by condition, notify
|
||||
shouldSelectChannel = false
|
||||
}
|
||||
}
|
||||
modelRequest.Model = midjourneyModel
|
||||
}
|
||||
c.Set("relay_mode", relayMode)
|
||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
if err != nil {
|
||||
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
|
||||
return nil, false, err
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if modelRequest.Model == "" {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
modelRequest.Model = "tts-1"
|
||||
} else {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
}
|
||||
return &modelRequest, shouldSelectChannel, nil
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
c.Set("original_model", modelName) // for retry
|
||||
if channel == nil {
|
||||
return
|
||||
}
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
@@ -168,7 +177,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
}
|
||||
c.Set("auto_ban", ban)
|
||||
c.Set("model_mapping", channel.GetModelMapping())
|
||||
c.Set("original_model", modelName) // for retry
|
||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// TODO: api_version统一
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/samber/lo"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
@@ -134,7 +135,16 @@ func (channel *Channel) AddAbilities() error {
|
||||
abilities = append(abilities, ability)
|
||||
}
|
||||
}
|
||||
return DB.Create(&abilities).Error
|
||||
if len(abilities) == 0 {
|
||||
return nil
|
||||
}
|
||||
for _, chunk := range lo.Chunk(abilities, 50) {
|
||||
err := DB.Create(&chunk).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (channel *Channel) DeleteAbilities() error {
|
||||
|
||||
@@ -92,6 +92,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
|
||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||
@@ -195,6 +196,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.DefaultCollapseSidebar = boolValue
|
||||
case "MjNotifyEnabled":
|
||||
constant.MjNotifyEnabled = boolValue
|
||||
case "MjModeClearEnabled":
|
||||
constant.MjModeClearEnabled = boolValue
|
||||
case "CheckSensitiveEnabled":
|
||||
constant.CheckSensitiveEnabled = boolValue
|
||||
case "CheckSensitiveOnPromptEnabled":
|
||||
|
||||
@@ -57,6 +57,8 @@ func ShouldDisableChannel(err *relaymodel.OpenAIError, statusCode int) bool {
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
|
||||
return true
|
||||
} else if strings.HasPrefix(err.Message, "You exceeded your current quota") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -172,6 +172,15 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
// make new request with mapResult
|
||||
}
|
||||
if constant.MjModeClearEnabled {
|
||||
if prompt, ok := mapResult["prompt"].(string); ok {
|
||||
prompt = strings.Replace(prompt, "--fast", "", -1)
|
||||
prompt = strings.Replace(prompt, "--relax", "", -1)
|
||||
prompt = strings.Replace(prompt, "--turbo", "", -1)
|
||||
|
||||
mapResult["prompt"] = prompt
|
||||
}
|
||||
}
|
||||
reqBody, err := json.Marshal(mapResult)
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
||||
|
||||
@@ -36,6 +36,7 @@ const OperationSetting = () => {
|
||||
StopOnSensitiveEnabled: '',
|
||||
SensitiveWords: '',
|
||||
MjNotifyEnabled: '',
|
||||
MjModeClearEnabled: '',
|
||||
DrawingEnabled: '',
|
||||
DataExportEnabled: '',
|
||||
DataExportDefaultTime: 'hour',
|
||||
@@ -312,6 +313,12 @@ const OperationSetting = () => {
|
||||
name='MjNotifyEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
<Form.Checkbox
|
||||
checked={inputs.MjModeClearEnabled === 'true'}
|
||||
label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
|
||||
name='MjModeClearEnabled'
|
||||
onChange={handleInputChange}
|
||||
/>
|
||||
</Form.Group>
|
||||
<Divider />
|
||||
<Header as='h3'>屏蔽词过滤设置</Header>
|
||||
|
||||
Reference in New Issue
Block a user