merge upstream

Signed-off-by: wozulong <>
wozulong 2024-08-06 15:54:28 +08:00
commit 1c371300ab
24 changed files with 242 additions and 112 deletions

View File

@ -64,6 +64,7 @@
- `GET_MEDIA_TOKEN`: whether to count image tokens, defaults to `true`. When disabled, image tokens are no longer computed locally, which may cause billing to differ from upstream. This setting overrides the `GET_MEDIA_TOKEN_NOT_STREAM` option.
- `GET_MEDIA_TOKEN_NOT_STREAM`: whether to count image tokens for non-streaming (`stream=false`) requests, defaults to `true`.
- `UPDATE_TASK`: whether to update asynchronous tasks (Midjourney, Suno), defaults to `true`. When disabled, task progress will not be updated.
- `GEMINI_MODEL_MAP`: pins Gemini models to an API version (v1/v1beta), given as comma-separated `model:version` pairs, e.g. `-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta"`; leave it empty to use the default mapping (a short parsing sketch follows this list).
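For illustration only, a minimal standalone sketch of how a `GEMINI_MODEL_MAP` value in this `model:version` format can be parsed; it mirrors the `constant.InitEnv` change later in this commit, and the helper name `parseGeminiModelMap` is hypothetical, not part of the codebase:

```go
package main

import (
	"fmt"
	"os"
	"strings"
)

// parseGeminiModelMap splits a comma-separated list of "model:version"
// pairs into a map, reporting and skipping malformed entries.
func parseGeminiModelMap(raw string) map[string]string {
	result := map[string]string{}
	raw = strings.TrimSpace(raw)
	if raw == "" {
		return result
	}
	for _, pair := range strings.Split(raw, ",") {
		parts := strings.Split(pair, ":")
		if len(parts) == 2 {
			result[parts[0]] = parts[1]
		} else {
			fmt.Printf("invalid model version map entry: %s\n", pair)
		}
	}
	return result
}

func main() {
	os.Setenv("GEMINI_MODEL_MAP", "gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta")
	fmt.Println(parseGeminiModelMap(os.Getenv("GEMINI_MODEL_MAP")))
	// map[gemini-1.5-pro-001:v1beta gemini-1.5-pro-latest:v1beta]
}
```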
## Deployment
### Deployment Requirements

View File

@ -182,6 +182,7 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05,
"mj_upscale": 0.05,
"swap_face": 0.05,
"mj_upload": 0.05,
}
var (

View File

@ -1,7 +1,10 @@
package constant
import (
"fmt"
"one-api/common"
"os"
"strings"
)
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}

View File

@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
)
var MidjourneyModel2Action = map[string]string{
@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan,
"swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
}

View File

@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
}
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
if !channel.GetAutoBan() {
ban = false
}

View File

@ -2,6 +2,7 @@ package controller
import (
"bytes"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
channelName := c.GetString("channel_name")
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, channelType, channelName, openaiErr)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode)
if openaiErr != nil {
go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // request handled successfully, return directly
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
@ -90,7 +82,42 @@ func Relay(c *gin.Context) {
}
}
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil {
return false
}
@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false
}
if openaiErr.StatusCode == 408 {
@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true
}
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban")
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
// do not use context to get channel info; channel info may be inconsistent when processed asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message)
}
@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break
}
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)
common.LogInfo(c, retryLogStr)
}
if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests {

View File

@ -806,11 +806,11 @@ type topUpRequest struct {
Key string `json:"key"`
}
var lock = sync.Mutex{}
var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) {
lock.Lock()
defer lock.Unlock()
topUpLock.Lock()
defer topUpLock.Unlock()
req := topUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {

View File

@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"`
}
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"`
Response MidjourneyResponse

View File

@ -55,6 +55,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error())
}
// Initialize constants
constant.InitEnv()
// Initialize options
model.InitOptionMap()
if common.RedisEnabled {

View File

@ -153,6 +153,12 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0]
}
token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.Id)
}
}
if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return

View File

@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil {
return
}
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization)
}
c.Set("auto_ban", ban)
c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes)
}
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error {
return DB.Save(channel).Error
}

View File

@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"one-api/common"
"strings"
"time"
)
type Log struct {
@ -102,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name like ?", "%"+modelName+"%")
tx = tx.Where("model_name like ?", modelName)
}
if username != "" {
tx = tx.Where("username = ?", username)
@ -137,7 +138,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
tx = tx.Where("model_name like ?", modelName)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
@ -185,12 +186,18 @@ type Stat struct {
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
tx := DB.Table("logs").Select("sum(quota) quota")
// build a separate query for rpm and tpm
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" {
tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
}
if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
}
if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp)
@ -200,11 +207,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
}
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// only count rpm and tpm for the last 60 seconds
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// run the queries
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat
}

View File

@ -50,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
return token, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
return token, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled {
@ -65,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error())
}
}
return nil, errors.New("该令牌已过期")
return token, errors.New("该令牌已过期")
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
@ -78,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
}
keyPrefix := key[:3]
keySuffix := key[len(key)-3:]
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
}
return token, nil
}

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
// a map of model names to their corresponding API versions
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// look up the version for the model in the map; if not found, fall back to info.ApiVersion or the default version "v1"
version, beta := modelVersionMap[info.UpstreamModelName]
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta {
if info.ApiVersion != "" {
version = info.ApiVersion

View File

@ -33,7 +33,7 @@ type RelayInfo struct {
}
func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")
@ -112,7 +112,7 @@ type TaskRelayInfo struct {
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel")
channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id")

View File

@ -27,6 +27,7 @@ const (
RelayModeMidjourneyModal
RelayModeMidjourneyShorten
RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper
@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus
relayMode = RelayModeSwapFace
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
// midjourney plus
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") {

View File

@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { // drawing task; tasks of this kind can be repeated
midjRequest.Action = constant.MjActionBlend
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { // drawing task; tasks of this kind can be repeated
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { // upscale/variation tasks; if repeated and a result already exists, the remote API returns the final result directly
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil {
common.SysError("get_channel_null: " + err.Error())
}
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
}
}
@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody)
}
}
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
midjourneyTask.Progress = "100%"
midjourneyTask.Status = "SUCCESS"
}
err = midjourneyTask.Insert()
if err != nil {
return &dto.MidjourneyResponse{
@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody)
}
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))

View File

@ -262,13 +262,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota {
// the token has ample quota, trust it
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
}
}
if preConsumedQuota > 0 {
@ -295,7 +295,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += " ,(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens

View File

@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
}
}

View File

@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = constant.MjActionModal
case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneyUpload:
action = constant.MjActionUpload
case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil {
@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
}
var midjResponse dto.MidjourneyResponse
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
}
respStr := string(responseBody)
log.Printf("responseBody: %s", respStr)
log.Printf("respStr: %s", respStr)
if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else {
err = json.Unmarshal(responseBody, &midjResponse)
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
if err2 != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
}
}
}
//log.Printf("midjResponse: %v", midjResponse)

View File

@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
import {
API,
copy,
getTodayStartTimestamp,
isAdmin,
showError,
showSuccess,
@ -412,19 +413,19 @@ const LogsTable = () => {
const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(0);
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin();
let now = new Date();
// initialize start_timestamp to the previous day
// initialize start_timestamp to the start of today (00:00)
const [inputs, setInputs] = useState({
username: '',
token_name: '',
model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
start_timestamp: timestamp2string(getTodayStartTimestamp()),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: '',
});
@ -449,9 +450,9 @@ const LogsTable = () => {
const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`,
);
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
url = encodeURI(url);
let res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@ -463,9 +464,9 @@ const LogsTable = () => {
const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get(
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`,
);
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
url = encodeURI(url);
let res = await API.get(url);
const { success, message, data } = res.data;
if (success) {
setStat(data);
@ -475,6 +476,9 @@ const LogsTable = () => {
};
const handleEyeClick = async () => {
if (loadingStat) {
return;
}
setLoadingStat(true);
if (isAdminUser) {
await getLogStat();
@ -509,14 +513,14 @@ const LogsTable = () => {
}
};
const setLogsFormat = (logs, total) => {
const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id;
}
// data.key = '' + data.id
setLogs(logs);
setLogCount(total);
setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount);
};
@ -531,15 +535,16 @@ const LogsTable = () => {
} else {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
}
url = encodeURI(url);
const res = await API.get(url);
const { success, message, total, data } = res.data;
const { success, message, data } = res.data;
if (success) {
if (startIdx === 0) {
setLogsFormat(data, total);
setLogsFormat(data);
} else {
let newLogs = [...logs];
newLogs.splice(startIdx * pageSize, data.length, ...data);
setLogsFormat(newLogs, total);
setLogsFormat(newLogs);
}
} else {
showError(message);
@ -574,6 +579,7 @@ const LogsTable = () => {
const refresh = async () => {
// setLoading(true);
setActivePage(1);
handleEyeClick();
await loadLogs(0, pageSize, logType);
};
@ -596,6 +602,7 @@ const LogsTable = () => {
.catch((reason) => {
showError(reason);
});
handleEyeClick();
}, []);
const searchLogs = async () => {
@ -622,19 +629,17 @@ const LogsTable = () => {
<Layout>
<Header>
<Spin spinning={loadingStat}>
<h3>
使用明细总消耗额度
<span
onClick={handleEyeClick}
style={{
cursor: 'pointer',
color: 'gray',
}}
>
{showStat ? renderQuota(stat.quota) : '点击查看'}
</span>
</h3>
<Space>
<Tag color='green' size='large' style={{ padding: 15 }}>
总消耗额度: {renderQuota(stat.quota)}
</Tag>
<Tag color='blue' size='large' style={{ padding: 15 }}>
RPM: {stat.rpm}
</Tag>
<Tag color='purple' size='large' style={{ padding: 15 }}>
TPM: {stat.tpm}
</Tag>
</Space>
</Spin>
</Header>
<Form layout='horizontal' style={{ marginTop: 10 }}>
@ -700,20 +705,18 @@ const LogsTable = () => {
/>
</>
)}
<Form.Section>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
onClick={() => {
refresh(logType).then();
}}
loading={loading}
>
查询
</Button>
</Form.Section>
<Button
label='查询'
type='primary'
htmlType='submit'
className='btn-margin-right'
onClick={refresh}
loading={loading}
style={{ marginTop: 24 }}
>
查询
</Button>
<Form.Section></Form.Section>
</>
</Form>
<Table

View File

@ -90,6 +90,12 @@ function renderType(type) {
图混合
</Tag>
);
case 'UPLOAD':
return (
<Tag color='blue' size='large'>
上传文件
</Tag>
);
case 'SHORTEN':
return (
<Tag color='pink' size='large'>

View File

@ -144,6 +144,12 @@ export function removeTrailingSlash(url) {
}
}
export function getTodayStartTimestamp() {
var now = new Date();
now.setHours(0, 0, 0, 0);
return Math.floor(now.getTime() / 1000);
}
export function timestamp2string(timestamp) {
let date = new Date(timestamp * 1000);
let year = date.getFullYear().toString();