diff --git a/README.md b/README.md index 67260f2..4a7f4c6 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 +- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 ## 部署 ### 部署要求 diff --git a/common/model-ratio.go b/common/model-ratio.go index 242ccd7..d4ea88b 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -182,6 +182,7 @@ var defaultModelPrice = map[string]float64{ "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05, + "mj_upload": 0.05, } var ( diff --git a/constant/env.go b/constant/env.go index 76146ca..dd3ae65 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,7 +1,10 @@ package constant import ( + "fmt" "one-api/common" + "os" + "strings" ) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30) @@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) + +var GeminiModelMap = map[string]string{ + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", +} + +func InitEnv() { + modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) + if modelVersionMapStr == "" { + return + } + for _, pair := range strings.Split(modelVersionMapStr, ",") { + parts := strings.Split(pair, ":") + if len(parts) == 2 { + GeminiModelMap[parts[0]] = parts[1] + } else { + common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) + } + } +} diff --git a/constant/midjourney.go b/constant/midjourney.go index 1e9d30a..4fba479 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -27,6 +27,7 @@ const ( MjActionLowVariation = "LOW_VARIATION" MjActionPan = "PAN" MjActionSwapFace = "SWAP_FACE" + MjActionUpload = "UPLOAD" ) var MidjourneyModel2Action = map[string]string{ @@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{ "mj_low_variation": MjActionLowVariation, "mj_pan": MjActionPan, "swap_face": MjActionSwapFace, + "mj_upload": MjActionUpload, } diff --git a/controller/channel-test.go b/controller/channel-test.go index fe27978..95c4a60 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -240,7 +240,7 @@ func testAllChannels(notify bool) error { } // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { + if !channel.GetAutoBan() { ban = false } diff --git a/controller/relay.go b/controller/relay.go index 0c79015..13fbde0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) - retryTimes := common.RetryTimes requestId := c.GetString(common.RequestIdKey) - channelId := c.GetInt("channel_id") - channelType := c.GetInt("channel_type") - channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString("original_model") - openaiErr := relayHandler(c, relayMode) - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - if openaiErr != nil { - go processChannelError(c, channelId, channelType, channelName, openaiErr) - } else { - retryTimes = 0 - } - for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ { - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + var openaiErr *dto.OpenAIErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + common.LogError(c, err.Error()) + openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) break } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id)) - c.Set("use_channel", useChannel) - common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) - middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - openaiErr = relayHandler(c, relayMode) - if openaiErr != nil { - go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + openaiErr = relayRequest(c, relayMode, channel) + + if openaiErr == nil { + return // 成功处理请求,直接返回 + } + + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break } } useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c.Request.Context(), retryLogStr) + common.LogInfo(c, retryLogStr) } if openaiErr != nil { @@ -90,7 +82,42 @@ func Relay(c *gin.Context) { } } -func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { +func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relayHandler(c, relayMode) +} + +func addUsedChannel(c *gin.Context, channelId int) { + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) +} + +func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { + if retryCount == 0 { + autoBan := c.GetBool("auto_ban") + autoBanInt := 1 + if !autoBan { + autoBanInt = 0 + } + return &model.Channel{ + Id: c.GetInt("channel_id"), + Type: c.GetInt("channel_type"), + Name: c.GetString("channel_name"), + AutoBan: &autoBanInt, + }, nil + } + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) + if err != nil { + return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) + } + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + return channel, nil +} + +func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { if openaiErr == nil { return false } @@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } if openaiErr.StatusCode == http.StatusBadRequest { + channelType := c.GetInt("channel_type") + if channelType == common.ChannelTypeAnthropic { + return true + } return false } if openaiErr.StatusCode == 408 { @@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } -func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) { - autoBan := c.GetBool("auto_ban") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) +func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously + common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) if service.ShouldDisableChannel(channelType, err) && autoBan { service.DisableChannel(channelId, channelName, err.Error.Message) } @@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) if err != nil { - common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) break } channelId = channel.Id useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) @@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c.Request.Context(), retryLogStr) + common.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { diff --git a/controller/user.go b/controller/user.go index 4048713..d040ee3 100644 --- a/controller/user.go +++ b/controller/user.go @@ -806,11 +806,11 @@ type topUpRequest struct { Key string `json:"key"` } -var lock = sync.Mutex{} +var topUpLock = sync.Mutex{} func TopUp(c *gin.Context) { - lock.Lock() - defer lock.Unlock() + topUpLock.Lock() + defer topUpLock.Unlock() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { diff --git a/dto/midjourney.go b/dto/midjourney.go index c675f7e..40251ee 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -33,6 +33,12 @@ type MidjourneyResponse struct { Result string `json:"result"` } +type MidjourneyUploadResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Result []string `json:"result"` +} + type MidjourneyResponseWithStatusCode struct { StatusCode int `json:"statusCode"` Response MidjourneyResponse diff --git a/main.go b/main.go index 959b795..efee772 100644 --- a/main.go +++ b/main.go @@ -55,6 +55,8 @@ func main() { common.FatalLog("failed to initialize Redis: " + err.Error()) } + // Initialize constants + constant.InitEnv() // Initialize options model.InitOptionMap() if common.RedisEnabled { diff --git a/middleware/auth.go b/middleware/auth.go index abf15d8..5215f7f 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -153,6 +153,12 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] } token, err := model.ValidateUserToken(key) + if token != nil { + id := c.GetInt("id") + if id == 0 { + c.Set("id", token.Id) + } + } if err != nil { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return diff --git a/middleware/distributor.go b/middleware/distributor.go index f150b41..1be3b31 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel == nil { return } - c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } - c.Set("auto_ban", ban) + c.Set("auto_ban", channel.GetAutoBan()) c.Set("model_mapping", channel.GetModelMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/model/channel.go b/model/channel.go index 7db3f07..3f9d9ed 100644 --- a/model/channel.go +++ b/model/channel.go @@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { channel.OtherInfo = string(otherInfoBytes) } +func (channel *Channel) GetAutoBan() bool { + if channel.AutoBan == nil { + return false + } + return *channel.AutoBan == 1 +} + func (channel *Channel) Save() error { return DB.Save(channel).Error } diff --git a/model/log.go b/model/log.go index 528a7e2..204e049 100644 --- a/model/log.go +++ b/model/log.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "one-api/common" "strings" + "time" ) type Log struct { @@ -102,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = DB.Where("type = ?", logType) } if modelName != "" { - tx = tx.Where("model_name like ?", "%"+modelName+"%") + tx = tx.Where("model_name like ?", modelName) } if username != "" { tx = tx.Where("username = ?", username) @@ -137,7 +138,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int tx = DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { - tx = tx.Where("model_name = ?", modelName) + tx = tx.Where("model_name like ?", modelName) } if tokenName != "" { tx = tx.Where("token_name = ?", tokenName) @@ -185,12 +186,18 @@ type Stat struct { } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { - tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") + tx := DB.Table("logs").Select("sum(quota) quota") + + // 为rpm和tpm创建单独的查询 + rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") + if username != "" { tx = tx.Where("username = ?", username) + rpmTpmQuery = rpmTpmQuery.Where("username = ?", username) } if tokenName != "" { tx = tx.Where("token_name = ?", tokenName) + rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName) } if startTimestamp != 0 { tx = tx.Where("created_at >= ?", startTimestamp) @@ -200,11 +207,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } if modelName != "" { tx = tx.Where("model_name = ?", modelName) + rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName) } if channel != 0 { tx = tx.Where("channel_id = ?", channel) + rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) } - tx.Where("type = ?", LogTypeConsume).Scan(&stat) + + tx = tx.Where("type = ?", LogTypeConsume) + rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume) + + // 只统计最近60秒的rpm和tpm + rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix()) + + // 执行查询 + tx.Scan(&stat) + rpmTpmQuery.Scan(&stat) + return stat } diff --git a/model/token.go b/model/token.go index d269c40..2919051 100644 --- a/model/token.go +++ b/model/token.go @@ -50,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status == common.TokenStatusExhausted { keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") + return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") } else if token.Status == common.TokenStatusExpired { - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if token.Status != common.TokenStatusEnabled { - return nil, errors.New("该令牌状态不可用") + return token, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if !common.RedisEnabled { @@ -65,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) { common.SysError("failed to update token status" + err.Error()) } } - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { @@ -78,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) + return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) } return token, nil } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index e132d2f..4c4649f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "one-api/constant" "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" @@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} -// 定义一个映射,存储模型名称和对应的版本 -var modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-ultra": "v1beta", } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := modelVersionMap[info.UpstreamModelName] + version, beta := constant.GeminiModelMap[info.UpstreamModelName] if !beta { if info.ApiVersion != "" { version = info.ApiVersion diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 564a7ad..3ed5ee3 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -33,7 +33,7 @@ type RelayInfo struct { } func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := c.GetInt("channel") + channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") @@ -112,7 +112,7 @@ type TaskRelayInfo struct { } func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - channelType := c.GetInt("channel") + channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index a072c74..6006bc6 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -27,6 +27,7 @@ const ( RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + RelayModeMidjourneyUpload RelayModeAudioSpeech // tts RelayModeAudioTranscription // whisper @@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int { } else if strings.HasSuffix(path, "/mj/insight-face/swap") { // midjourney plus relayMode = RelayModeSwapFace + } else if strings.HasSuffix(path, "/submit/upload-discord-images") { + // midjourney plus + relayMode = RelayModeMidjourneyUpload } else if strings.HasSuffix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine } else if strings.HasSuffix(path, "/mj/submit/blend") { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 471dedb..4164e3a 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend + } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { @@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if err != nil { common.SysError("get_channel_null: " + err.Error()) } - if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled { + if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") } } @@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody = []byte(newBody) } } - + if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" { + midjourneyTask.Progress = "100%" + midjourneyTask.Status = "SUCCESS" + } err = midjourneyTask.Insert() if err != nil { return &dto.MidjourneyResponse{ @@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) responseBody = []byte(newBody) } - //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/relay/relay-text.go b/relay/relay-text.go index 089ea5e..7b9f7d9 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -262,13 +262,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > 100*preConsumedQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) } } if preConsumedQuota > 0 { @@ -295,7 +295,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64, usePrice bool, extraContent string) { - + if usage == nil { + usage = &dto.Usage{ + PromptTokens: relayInfo.PromptTokens, + CompletionTokens: 0, + TotalTokens: relayInfo.PromptTokens, + } + extraContent += " ,(可能是请求出错)" + } useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens diff --git a/router/relay-router.go b/router/relay-router.go index 2bf2ca2..b8d4a43 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) + relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney) } } diff --git a/service/midjourney.go b/service/midjourney.go index 6bb3a9e..d7b2a9a 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin action = constant.MjActionModal case relayconstant.RelayModeSwapFace: action = constant.MjActionSwapFace + case relayconstant.RelayModeMidjourneyUpload: + action = constant.MjActionUpload case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err } var midjResponse dto.MidjourneyResponse - + var midjourneyUploadsResponse dto.MidjourneyUploadResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err @@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err } respStr := string(responseBody) - log.Printf("responseBody: %s", respStr) + log.Printf("respStr: %s", respStr) if respStr == "" { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil } else { err = json.Unmarshal(responseBody, &midjResponse) if err != nil { - return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse) + if err2 != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } } } //log.Printf("midjResponse: %v", midjResponse) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index d578c15..127d738 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react'; import { API, copy, + getTodayStartTimestamp, isAdmin, showError, showSuccess, @@ -412,19 +413,19 @@ const LogsTable = () => { const [loading, setLoading] = useState(false); const [loadingStat, setLoadingStat] = useState(false); const [activePage, setActivePage] = useState(1); - const [logCount, setLogCount] = useState(0); + const [logCount, setLogCount] = useState(ITEMS_PER_PAGE); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); const [searchKeyword, setSearchKeyword] = useState(''); const [searching, setSearching] = useState(false); const [logType, setLogType] = useState(0); const isAdminUser = isAdmin(); let now = new Date(); - // 初始化start_timestamp为前一天 + // 初始化start_timestamp为今天0点 const [inputs, setInputs] = useState({ username: '', token_name: '', model_name: '', - start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), + start_timestamp: timestamp2string(getTodayStartTimestamp()), end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), channel: '', }); @@ -449,9 +450,9 @@ const LogsTable = () => { const getLogSelfStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; - let res = await API.get( - `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`, - ); + let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + url = encodeURI(url); + let res = await API.get(url); const { success, message, data } = res.data; if (success) { setStat(data); @@ -463,9 +464,9 @@ const LogsTable = () => { const getLogStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; - let res = await API.get( - `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`, - ); + let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; + url = encodeURI(url); + let res = await API.get(url); const { success, message, data } = res.data; if (success) { setStat(data); @@ -475,6 +476,9 @@ const LogsTable = () => { }; const handleEyeClick = async () => { + if (loadingStat) { + return; + } setLoadingStat(true); if (isAdminUser) { await getLogStat(); @@ -509,14 +513,14 @@ const LogsTable = () => { } }; - const setLogsFormat = (logs, total) => { + const setLogsFormat = (logs) => { for (let i = 0; i < logs.length; i++) { logs[i].timestamp2string = timestamp2string(logs[i].created_at); logs[i].key = '' + logs[i].id; } // data.key = '' + data.id setLogs(logs); - setLogCount(total); + setLogCount(logs.length + ITEMS_PER_PAGE); // console.log(logCount); }; @@ -531,15 +535,16 @@ const LogsTable = () => { } else { url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; } + url = encodeURI(url); const res = await API.get(url); - const { success, message, total, data } = res.data; + const { success, message, data } = res.data; if (success) { if (startIdx === 0) { - setLogsFormat(data, total); + setLogsFormat(data); } else { let newLogs = [...logs]; newLogs.splice(startIdx * pageSize, data.length, ...data); - setLogsFormat(newLogs, total); + setLogsFormat(newLogs); } } else { showError(message); @@ -574,6 +579,7 @@ const LogsTable = () => { const refresh = async () => { // setLoading(true); setActivePage(1); + handleEyeClick(); await loadLogs(0, pageSize, logType); }; @@ -596,6 +602,7 @@ const LogsTable = () => { .catch((reason) => { showError(reason); }); + handleEyeClick(); }, []); const searchLogs = async () => { @@ -622,19 +629,17 @@ const LogsTable = () => {
-

- 使用明细(总消耗额度: - - {showStat ? renderQuota(stat.quota) : '点击查看'} - - ) -

+ + + 总消耗额度: {renderQuota(stat.quota)} + + + RPM: {stat.rpm} + + + TPM: {stat.tpm} + +
@@ -700,20 +705,18 @@ const LogsTable = () => { /> )} - - - + +
); + case 'UPLOAD': + return ( + + 上传文件 + + ); case 'SHORTEN': return ( diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index 8af5f68..6c27790 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -144,6 +144,12 @@ export function removeTrailingSlash(url) { } } +export function getTodayStartTimestamp() { + var now = new Date(); + now.setHours(0, 0, 0, 0); + return Math.floor(now.getTime() / 1000); +} + export function timestamp2string(timestamp) { let date = new Date(timestamp * 1000); let year = date.getFullYear().toString();