merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong 2024-08-06 15:54:28 +08:00
commit 1c371300ab
24 changed files with 242 additions and 112 deletions

View File

@ -64,6 +64,7 @@
- `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN`是统计图片token默认为 `true`关闭后将不再在本地计算图片token可能会导致和上游计费不同此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true` - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`情况下统计图片token默认为 `true`
- `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。 - `UPDATE_TASK`是否更新异步任务Midjourney、Suno默认为 `true`,关闭后将不会更新任务进度。
- `GEMINI_MODEL_MAP`Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
## 部署 ## 部署
### 部署要求 ### 部署要求

View File

@ -182,6 +182,7 @@ var defaultModelPrice = map[string]float64{
"mj_describe": 0.05, "mj_describe": 0.05,
"mj_upscale": 0.05, "mj_upscale": 0.05,
"swap_face": 0.05, "swap_face": 0.05,
"mj_upload": 0.05,
} }
var ( var (

View File

@ -1,7 +1,10 @@
package constant package constant
import ( import (
"fmt"
"one-api/common" "one-api/common"
"os"
"strings"
) )
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
var GeminiModelMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-pro-001": "v1beta",
"gemini-1.5-pro": "v1beta",
"gemini-1.5-pro-exp-0801": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-1.5-flash-001": "v1beta",
"gemini-1.5-flash": "v1beta",
"gemini-ultra": "v1beta",
}
func InitEnv() {
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
if modelVersionMapStr == "" {
return
}
for _, pair := range strings.Split(modelVersionMapStr, ",") {
parts := strings.Split(pair, ":")
if len(parts) == 2 {
GeminiModelMap[parts[0]] = parts[1]
} else {
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
}
}
}

View File

@ -27,6 +27,7 @@ const (
MjActionLowVariation = "LOW_VARIATION" MjActionLowVariation = "LOW_VARIATION"
MjActionPan = "PAN" MjActionPan = "PAN"
MjActionSwapFace = "SWAP_FACE" MjActionSwapFace = "SWAP_FACE"
MjActionUpload = "UPLOAD"
) )
var MidjourneyModel2Action = map[string]string{ var MidjourneyModel2Action = map[string]string{
@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
"mj_low_variation": MjActionLowVariation, "mj_low_variation": MjActionLowVariation,
"mj_pan": MjActionPan, "mj_pan": MjActionPan,
"swap_face": MjActionSwapFace, "swap_face": MjActionSwapFace,
"mj_upload": MjActionUpload,
} }

View File

@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
} }
// parse *int to bool // parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 { if !channel.GetAutoBan() {
ban = false ban = false
} }

View File

@ -2,6 +2,7 @@ package controller
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := constant.Path2RelayMode(c.Request.URL.Path)
retryTimes := common.RetryTimes
requestId := c.GetString(common.RequestIdKey) requestId := c.GetString(common.RequestIdKey)
channelId := c.GetInt("channel_id")
channelType := c.GetInt("channel_type")
channelName := c.GetString("channel_name")
group := c.GetString("group") group := c.GetString("group")
originalModel := c.GetString("original_model") originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode) var openaiErr *dto.OpenAIErrorWithStatusCode
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil { for i := 0; i <= common.RetryTimes; i++ {
go processChannelError(c, channelId, channelType, channelName, openaiErr) channel, err := getChannel(c, group, originalModel, i)
} else {
retryTimes = 0
}
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break break
} }
channelId = channel.Id
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) openaiErr = relayRequest(c, relayMode, channel)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
openaiErr = relayHandler(c, relayMode) if openaiErr == nil {
if openaiErr != nil { return // 成功处理请求,直接返回
go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) }
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
} }
} }
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr) common.LogInfo(c, retryLogStr)
} }
if openaiErr != nil { if openaiErr != nil {
@ -90,7 +82,42 @@ func Relay(c *gin.Context) {
} }
} }
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return relayHandler(c, relayMode)
}
func addUsedChannel(c *gin.Context, channelId int) {
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
}
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
if retryCount == 0 {
autoBan := c.GetBool("auto_ban")
autoBanInt := 1
if !autoBan {
autoBanInt = 0
}
return &model.Channel{
Id: c.GetInt("channel_id"),
Type: c.GetInt("channel_type"),
Name: c.GetString("channel_name"),
AutoBan: &autoBanInt,
}, nil
}
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
if err != nil {
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
}
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
return channel, nil
}
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
if openaiErr == nil { if openaiErr == nil {
return false return false
} }
@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true return true
} }
if openaiErr.StatusCode == http.StatusBadRequest { if openaiErr.StatusCode == http.StatusBadRequest {
channelType := c.GetInt("channel_type")
if channelType == common.ChannelTypeAnthropic {
return true
}
return false return false
} }
if openaiErr.StatusCode == 408 { if openaiErr.StatusCode == 408 {
@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
return true return true
} }
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) { func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
autoBan := c.GetBool("auto_ban") // 不要使用context获取渠道信息异步处理时可能会出现渠道信息不一致的情况
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
if service.ShouldDisableChannel(channelType, err) && autoBan { if service.ShouldDisableChannel(channelType, err) && autoBan {
service.DisableChannel(channelId, channelName, err.Error.Message) service.DisableChannel(channelId, channelName, err.Error.Message)
} }
@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
if err != nil { if err != nil {
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
break break
} }
channelId = channel.Id channelId = channel.Id
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel) c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel) middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, err := common.GetRequestBody(c)
@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
useChannel := c.GetStringSlice("use_channel") useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 { if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr) common.LogInfo(c, retryLogStr)
} }
if taskErr != nil { if taskErr != nil {
if taskErr.StatusCode == http.StatusTooManyRequests { if taskErr.StatusCode == http.StatusTooManyRequests {

View File

@ -806,11 +806,11 @@ type topUpRequest struct {
Key string `json:"key"` Key string `json:"key"`
} }
var lock = sync.Mutex{} var topUpLock = sync.Mutex{}
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
lock.Lock() topUpLock.Lock()
defer lock.Unlock() defer topUpLock.Unlock()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {

View File

@ -33,6 +33,12 @@ type MidjourneyResponse struct {
Result string `json:"result"` Result string `json:"result"`
} }
type MidjourneyUploadResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result []string `json:"result"`
}
type MidjourneyResponseWithStatusCode struct { type MidjourneyResponseWithStatusCode struct {
StatusCode int `json:"statusCode"` StatusCode int `json:"statusCode"`
Response MidjourneyResponse Response MidjourneyResponse

View File

@ -55,6 +55,8 @@ func main() {
common.FatalLog("failed to initialize Redis: " + err.Error()) common.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize constants
constant.InitEnv()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
if common.RedisEnabled { if common.RedisEnabled {

View File

@ -153,6 +153,12 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
} }
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if token != nil {
id := c.GetInt("id")
if id == 0 {
c.Set("id", token.Id)
}
}
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
return return

View File

@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
if channel == nil { if channel == nil {
return return
} }
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("channel_type", channel.Type) c.Set("channel_type", channel.Type)
ban := true
// parse *int to bool
if channel.AutoBan != nil && *channel.AutoBan == 0 {
ban = false
}
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
c.Set("channel_organization", *channel.OpenAIOrganization) c.Set("channel_organization", *channel.OpenAIOrganization)
} }
c.Set("auto_ban", ban) c.Set("auto_ban", channel.GetAutoBan())
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.GetModelMapping())
c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping())
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
channel.OtherInfo = string(otherInfoBytes) channel.OtherInfo = string(otherInfoBytes)
} }
func (channel *Channel) GetAutoBan() bool {
if channel.AutoBan == nil {
return false
}
return *channel.AutoBan == 1
}
func (channel *Channel) Save() error { func (channel *Channel) Save() error {
return DB.Save(channel).Error return DB.Save(channel).Error
} }

View File

@ -7,6 +7,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strings" "strings"
"time"
) )
type Log struct { type Log struct {
@ -102,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
tx = DB.Where("type = ?", logType) tx = DB.Where("type = ?", logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name like ?", "%"+modelName+"%") tx = tx.Where("model_name like ?", modelName)
} }
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
@ -137,7 +138,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
tx = DB.Where("user_id = ? and type = ?", userId, logType) tx = DB.Where("user_id = ? and type = ?", userId, logType)
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name like ?", modelName)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
@ -185,12 +186,18 @@ type Stat struct {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") tx := DB.Table("logs").Select("sum(quota) quota")
// 为rpm和tpm创建单独的查询
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
} }
if tokenName != "" { if tokenName != "" {
tx = tx.Where("token_name = ?", tokenName) tx = tx.Where("token_name = ?", tokenName)
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
} }
if startTimestamp != 0 { if startTimestamp != 0 {
tx = tx.Where("created_at >= ?", startTimestamp) tx = tx.Where("created_at >= ?", startTimestamp)
@ -200,11 +207,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
} }
if channel != 0 { if channel != 0 {
tx = tx.Where("channel_id = ?", channel) tx = tx.Where("channel_id = ?", channel)
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
} }
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
tx = tx.Where("type = ?", LogTypeConsume)
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
// 只统计最近60秒的rpm和tpm
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
// 执行查询
tx.Scan(&stat)
rpmTpmQuery.Scan(&stat)
return stat return stat
} }

View File

@ -50,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
if token.Status == common.TokenStatusExhausted { if token.Status == common.TokenStatusExhausted {
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
} else if token.Status == common.TokenStatusExpired { } else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期") return token, errors.New("该令牌已过期")
} }
if token.Status != common.TokenStatusEnabled { if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用") return token, errors.New("该令牌状态不可用")
} }
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
if !common.RedisEnabled { if !common.RedisEnabled {
@ -65,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
common.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
} }
return nil, errors.New("该令牌已过期") return token, errors.New("该令牌已过期")
} }
if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled { if !common.RedisEnabled {
@ -78,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
} }
keyPrefix := key[:3] keyPrefix := key[:3]
keySuffix := key[len(key)-3:] keySuffix := key[len(key)-3:]
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
} }
return token, nil return token, nil
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
// 定义一个映射,存储模型名称和对应的版本
var modelVersionMap = map[string]string{
"gemini-1.5-pro-latest": "v1beta",
"gemini-1.5-flash-latest": "v1beta",
"gemini-ultra": "v1beta",
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
version, beta := modelVersionMap[info.UpstreamModelName] version, beta := constant.GeminiModelMap[info.UpstreamModelName]
if !beta { if !beta {
if info.ApiVersion != "" { if info.ApiVersion != "" {
version = info.ApiVersion version = info.ApiVersion

View File

@ -33,7 +33,7 @@ type RelayInfo struct {
} }
func GenRelayInfo(c *gin.Context) *RelayInfo { func GenRelayInfo(c *gin.Context) *RelayInfo {
channelType := c.GetInt("channel") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
@ -112,7 +112,7 @@ type TaskRelayInfo struct {
} }
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
channelType := c.GetInt("channel") channelType := c.GetInt("channel_type")
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")

View File

@ -27,6 +27,7 @@ const (
RelayModeMidjourneyModal RelayModeMidjourneyModal
RelayModeMidjourneyShorten RelayModeMidjourneyShorten
RelayModeSwapFace RelayModeSwapFace
RelayModeMidjourneyUpload
RelayModeAudioSpeech // tts RelayModeAudioSpeech // tts
RelayModeAudioTranscription // whisper RelayModeAudioTranscription // whisper
@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
} else if strings.HasSuffix(path, "/mj/insight-face/swap") { } else if strings.HasSuffix(path, "/mj/insight-face/swap") {
// midjourney plus // midjourney plus
relayMode = RelayModeSwapFace relayMode = RelayModeSwapFace
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
// midjourney plus
relayMode = RelayModeMidjourneyUpload
} else if strings.HasSuffix(path, "/mj/submit/imagine") { } else if strings.HasSuffix(path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine relayMode = RelayModeMidjourneyImagine
} else if strings.HasSuffix(path, "/mj/submit/blend") { } else if strings.HasSuffix(path, "/mj/submit/blend") {

View File

@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
midjRequest.Action = constant.MjActionShorten midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionBlend midjRequest.Action = constant.MjActionBlend
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果 } else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := "" mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange { if relayMode == relayconstant.RelayModeMidjourneyChange {
@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if err != nil { if err != nil {
common.SysError("get_channel_null: " + err.Error()) common.SysError("get_channel_null: " + err.Error())
} }
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled { if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
} }
} }
@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
responseBody = []byte(newBody) responseBody = []byte(newBody)
} }
} }
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
midjourneyTask.Progress = "100%"
midjourneyTask.Status = "SUCCESS"
}
err = midjourneyTask.Insert() err = midjourneyTask.Insert()
if err != nil { if err != nil {
return &dto.MidjourneyResponse{ return &dto.MidjourneyResponse{
@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
responseBody = []byte(newBody) responseBody = []byte(newBody)
} }
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))

View File

@ -262,13 +262,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
if tokenQuota > 100*preConsumedQuota { if tokenQuota > 100*preConsumedQuota {
// 令牌额度充足,信任令牌 // 令牌额度充足,信任令牌
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
} }
} else { } else {
// in this case, we do not pre-consume quota // in this case, we do not pre-consume quota
// because the user has enough quota // because the user has enough quota
preConsumedQuota = 0 preConsumedQuota = 0
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
} }
} }
if preConsumedQuota > 0 { if preConsumedQuota > 0 {
@ -295,7 +295,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, usePrice bool, extraContent string) { modelPrice float64, usePrice bool, extraContent string) {
if usage == nil {
usage = &dto.Usage{
PromptTokens: relayInfo.PromptTokens,
CompletionTokens: 0,
TotalTokens: relayInfo.PromptTokens,
}
extraContent += " ,(可能是请求出错)"
}
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens

View File

@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
} }
} }

View File

@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
action = constant.MjActionModal action = constant.MjActionModal
case relayconstant.RelayModeSwapFace: case relayconstant.RelayModeSwapFace:
action = constant.MjActionSwapFace action = constant.MjActionSwapFace
case relayconstant.RelayModeMidjourneyUpload:
action = constant.MjActionUpload
case relayconstant.RelayModeMidjourneySimpleChange: case relayconstant.RelayModeMidjourneySimpleChange:
params := ConvertSimpleChangeParams(midjRequest.Content) params := ConvertSimpleChangeParams(midjRequest.Content)
if params == nil { if params == nil {
@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
} }
var midjResponse dto.MidjourneyResponse var midjResponse dto.MidjourneyResponse
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
@ -230,15 +232,18 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
} }
respStr := string(responseBody) respStr := string(responseBody)
log.Printf("responseBody: %s", respStr) log.Printf("respStr: %s", respStr)
if respStr == "" { if respStr == "" {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
} else { } else {
err = json.Unmarshal(responseBody, &midjResponse) err = json.Unmarshal(responseBody, &midjResponse)
if err != nil { if err != nil {
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
if err2 != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
} }
} }
}
//log.Printf("midjResponse: %v", midjResponse) //log.Printf("midjResponse: %v", midjResponse)
//for k, v := range resp.Header { //for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0]) // c.Writer.Header().Set(k, v[0])

View File

@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
import { import {
API, API,
copy, copy,
getTodayStartTimestamp,
isAdmin, isAdmin,
showError, showError,
showSuccess, showSuccess,
@ -412,19 +413,19 @@ const LogsTable = () => {
const [loading, setLoading] = useState(false); const [loading, setLoading] = useState(false);
const [loadingStat, setLoadingStat] = useState(false); const [loadingStat, setLoadingStat] = useState(false);
const [activePage, setActivePage] = useState(1); const [activePage, setActivePage] = useState(1);
const [logCount, setLogCount] = useState(0); const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
const [searchKeyword, setSearchKeyword] = useState(''); const [searchKeyword, setSearchKeyword] = useState('');
const [searching, setSearching] = useState(false); const [searching, setSearching] = useState(false);
const [logType, setLogType] = useState(0); const [logType, setLogType] = useState(0);
const isAdminUser = isAdmin(); const isAdminUser = isAdmin();
let now = new Date(); let now = new Date();
// 初始化start_timestamp为前一天 // 初始化start_timestamp为今天0点
const [inputs, setInputs] = useState({ const [inputs, setInputs] = useState({
username: '', username: '',
token_name: '', token_name: '',
model_name: '', model_name: '',
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), start_timestamp: timestamp2string(getTodayStartTimestamp()),
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
channel: '', channel: '',
}); });
@ -449,9 +450,9 @@ const LogsTable = () => {
const getLogSelfStat = async () => { const getLogSelfStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get( let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`, url = encodeURI(url);
); let res = await API.get(url);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setStat(data); setStat(data);
@ -463,9 +464,9 @@ const LogsTable = () => {
const getLogStat = async () => { const getLogStat = async () => {
let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localStartTimestamp = Date.parse(start_timestamp) / 1000;
let localEndTimestamp = Date.parse(end_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000;
let res = await API.get( let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`, url = encodeURI(url);
); let res = await API.get(url);
const { success, message, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
setStat(data); setStat(data);
@ -475,6 +476,9 @@ const LogsTable = () => {
}; };
const handleEyeClick = async () => { const handleEyeClick = async () => {
if (loadingStat) {
return;
}
setLoadingStat(true); setLoadingStat(true);
if (isAdminUser) { if (isAdminUser) {
await getLogStat(); await getLogStat();
@ -509,14 +513,14 @@ const LogsTable = () => {
} }
}; };
const setLogsFormat = (logs, total) => { const setLogsFormat = (logs) => {
for (let i = 0; i < logs.length; i++) { for (let i = 0; i < logs.length; i++) {
logs[i].timestamp2string = timestamp2string(logs[i].created_at); logs[i].timestamp2string = timestamp2string(logs[i].created_at);
logs[i].key = '' + logs[i].id; logs[i].key = '' + logs[i].id;
} }
// data.key = '' + data.id // data.key = '' + data.id
setLogs(logs); setLogs(logs);
setLogCount(total); setLogCount(logs.length + ITEMS_PER_PAGE);
// console.log(logCount); // console.log(logCount);
}; };
@ -531,15 +535,16 @@ const LogsTable = () => {
} else { } else {
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
} }
url = encodeURI(url);
const res = await API.get(url); const res = await API.get(url);
const { success, message, total, data } = res.data; const { success, message, data } = res.data;
if (success) { if (success) {
if (startIdx === 0) { if (startIdx === 0) {
setLogsFormat(data, total); setLogsFormat(data);
} else { } else {
let newLogs = [...logs]; let newLogs = [...logs];
newLogs.splice(startIdx * pageSize, data.length, ...data); newLogs.splice(startIdx * pageSize, data.length, ...data);
setLogsFormat(newLogs, total); setLogsFormat(newLogs);
} }
} else { } else {
showError(message); showError(message);
@ -574,6 +579,7 @@ const LogsTable = () => {
const refresh = async () => { const refresh = async () => {
// setLoading(true); // setLoading(true);
setActivePage(1); setActivePage(1);
handleEyeClick();
await loadLogs(0, pageSize, logType); await loadLogs(0, pageSize, logType);
}; };
@ -596,6 +602,7 @@ const LogsTable = () => {
.catch((reason) => { .catch((reason) => {
showError(reason); showError(reason);
}); });
handleEyeClick();
}, []); }, []);
const searchLogs = async () => { const searchLogs = async () => {
@ -622,19 +629,17 @@ const LogsTable = () => {
<Layout> <Layout>
<Header> <Header>
<Spin spinning={loadingStat}> <Spin spinning={loadingStat}>
<h3> <Space>
使用明细总消耗额度 <Tag color='green' size='large' style={{ padding: 15 }}>
<span 总消耗额度: {renderQuota(stat.quota)}
onClick={handleEyeClick} </Tag>
style={{ <Tag color='blue' size='large' style={{ padding: 15 }}>
cursor: 'pointer', RPM: {stat.rpm}
color: 'gray', </Tag>
}} <Tag color='purple' size='large' style={{ padding: 15 }}>
> TPM: {stat.tpm}
{showStat ? renderQuota(stat.quota) : '点击查看'} </Tag>
</span> </Space>
</h3>
</Spin> </Spin>
</Header> </Header>
<Form layout='horizontal' style={{ marginTop: 10 }}> <Form layout='horizontal' style={{ marginTop: 10 }}>
@ -700,20 +705,18 @@ const LogsTable = () => {
/> />
</> </>
)} )}
<Form.Section>
<Button <Button
label='查询' label='查询'
type='primary' type='primary'
htmlType='submit' htmlType='submit'
className='btn-margin-right' className='btn-margin-right'
onClick={() => { onClick={refresh}
refresh(logType).then();
}}
loading={loading} loading={loading}
style={{ marginTop: 24 }}
> >
查询 查询
</Button> </Button>
</Form.Section> <Form.Section></Form.Section>
</> </>
</Form> </Form>
<Table <Table

View File

@ -90,6 +90,12 @@ function renderType(type) {
图混合 图混合
</Tag> </Tag>
); );
case 'UPLOAD':
return (
<Tag color='blue' size='large'>
上传文件
</Tag>
);
case 'SHORTEN': case 'SHORTEN':
return ( return (
<Tag color='pink' size='large'> <Tag color='pink' size='large'>

View File

@ -144,6 +144,12 @@ export function removeTrailingSlash(url) {
} }
} }
export function getTodayStartTimestamp() {
var now = new Date();
now.setHours(0, 0, 0, 0);
return Math.floor(now.getTime() / 1000);
}
export function timestamp2string(timestamp) { export function timestamp2string(timestamp) {
let date = new Date(timestamp * 1000); let date = new Date(timestamp * 1000);
let year = date.getFullYear().toString(); let year = date.getFullYear().toString();