mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
merge upstream
Signed-off-by: wozulong <>
This commit is contained in:
commit
1c371300ab
@ -64,6 +64,7 @@
|
|||||||
- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
||||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
||||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||||
|
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 部署要求
|
### 部署要求
|
||||||
|
@ -182,6 +182,7 @@ var defaultModelPrice = map[string]float64{
|
|||||||
"mj_describe": 0.05,
|
"mj_describe": 0.05,
|
||||||
"mj_upscale": 0.05,
|
"mj_upscale": 0.05,
|
||||||
"swap_face": 0.05,
|
"swap_face": 0.05,
|
||||||
|
"mj_upload": 0.05,
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -1,7 +1,10 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
||||||
@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
|||||||
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||||
|
|
||||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
|
||||||
|
var GeminiModelMap = map[string]string{
|
||||||
|
"gemini-1.5-pro-latest": "v1beta",
|
||||||
|
"gemini-1.5-pro-001": "v1beta",
|
||||||
|
"gemini-1.5-pro": "v1beta",
|
||||||
|
"gemini-1.5-pro-exp-0801": "v1beta",
|
||||||
|
"gemini-1.5-flash-latest": "v1beta",
|
||||||
|
"gemini-1.5-flash-001": "v1beta",
|
||||||
|
"gemini-1.5-flash": "v1beta",
|
||||||
|
"gemini-ultra": "v1beta",
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitEnv() {
|
||||||
|
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||||
|
if modelVersionMapStr == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||||
|
parts := strings.Split(pair, ":")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
GeminiModelMap[parts[0]] = parts[1]
|
||||||
|
} else {
|
||||||
|
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -27,6 +27,7 @@ const (
|
|||||||
MjActionLowVariation = "LOW_VARIATION"
|
MjActionLowVariation = "LOW_VARIATION"
|
||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
MjActionSwapFace = "SWAP_FACE"
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
|
MjActionUpload = "UPLOAD"
|
||||||
)
|
)
|
||||||
|
|
||||||
var MidjourneyModel2Action = map[string]string{
|
var MidjourneyModel2Action = map[string]string{
|
||||||
@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
|
|||||||
"mj_low_variation": MjActionLowVariation,
|
"mj_low_variation": MjActionLowVariation,
|
||||||
"mj_pan": MjActionPan,
|
"mj_pan": MjActionPan,
|
||||||
"swap_face": MjActionSwapFace,
|
"swap_face": MjActionSwapFace,
|
||||||
|
"mj_upload": MjActionUpload,
|
||||||
}
|
}
|
||||||
|
@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parse *int to bool
|
// parse *int to bool
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
if !channel.GetAutoBan() {
|
||||||
ban = false
|
ban = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
@ -39,44 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
retryTimes := common.RetryTimes
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
channelName := c.GetString("channel_name")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
openaiErr := relayHandler(c, relayMode)
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
|
||||||
if openaiErr != nil {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
go processChannelError(c, channelId, channelType, channelName, openaiErr)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
} else {
|
|
||||||
retryTimes = 0
|
|
||||||
}
|
|
||||||
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, err.Error())
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id))
|
|
||||||
c.Set("use_channel", useChannel)
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
openaiErr = relayRequest(c, relayMode, channel)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
openaiErr = relayHandler(c, relayMode)
|
if openaiErr == nil {
|
||||||
if openaiErr != nil {
|
return // 成功处理请求,直接返回
|
||||||
go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr)
|
}
|
||||||
|
|
||||||
|
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||||
|
|
||||||
|
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
@ -90,7 +82,42 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
return relayHandler(c, relayMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
|
c.Set("use_channel", useChannel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
||||||
|
if retryCount == 0 {
|
||||||
|
autoBan := c.GetBool("auto_ban")
|
||||||
|
autoBanInt := 1
|
||||||
|
if !autoBan {
|
||||||
|
autoBanInt = 0
|
||||||
|
}
|
||||||
|
return &model.Channel{
|
||||||
|
Id: c.GetInt("channel_id"),
|
||||||
|
Type: c.GetInt("channel_type"),
|
||||||
|
Name: c.GetString("channel_name"),
|
||||||
|
AutoBan: &autoBanInt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||||
|
}
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||||
if openaiErr == nil {
|
if openaiErr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -114,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
|
channelType := c.GetInt("channel_type")
|
||||||
|
if channelType == common.ChannelTypeAnthropic {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == 408 {
|
if openaiErr.StatusCode == 408 {
|
||||||
@ -129,9 +160,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) {
|
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
autoBan := c.GetBool("auto_ban")
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
|
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||||
}
|
}
|
||||||
@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
|
|||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, err := common.GetRequestBody(c)
|
||||||
@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
@ -806,11 +806,11 @@ type topUpRequest struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var lock = sync.Mutex{}
|
var topUpLock = sync.Mutex{}
|
||||||
|
|
||||||
func TopUp(c *gin.Context) {
|
func TopUp(c *gin.Context) {
|
||||||
lock.Lock()
|
topUpLock.Lock()
|
||||||
defer lock.Unlock()
|
defer topUpLock.Unlock()
|
||||||
req := topUpRequest{}
|
req := topUpRequest{}
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -33,6 +33,12 @@ type MidjourneyResponse struct {
|
|||||||
Result string `json:"result"`
|
Result string `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MidjourneyUploadResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Result []string `json:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyResponseWithStatusCode struct {
|
type MidjourneyResponseWithStatusCode struct {
|
||||||
StatusCode int `json:"statusCode"`
|
StatusCode int `json:"statusCode"`
|
||||||
Response MidjourneyResponse
|
Response MidjourneyResponse
|
||||||
|
2
main.go
2
main.go
@ -55,6 +55,8 @@ func main() {
|
|||||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize constants
|
||||||
|
constant.InitEnv()
|
||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
|
@ -153,6 +153,12 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
key = parts[0]
|
key = parts[0]
|
||||||
}
|
}
|
||||||
token, err := model.ValidateUserToken(key)
|
token, err := model.ValidateUserToken(key)
|
||||||
|
if token != nil {
|
||||||
|
id := c.GetInt("id")
|
||||||
|
if id == 0 {
|
||||||
|
c.Set("id", token.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error())
|
||||||
return
|
return
|
||||||
|
@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("channel_type", channel.Type)
|
c.Set("channel_type", channel.Type)
|
||||||
ban := true
|
|
||||||
// parse *int to bool
|
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
|
||||||
ban = false
|
|
||||||
}
|
|
||||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||||
}
|
}
|
||||||
c.Set("auto_ban", ban)
|
c.Set("auto_ban", channel.GetAutoBan())
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
|||||||
channel.OtherInfo = string(otherInfoBytes)
|
channel.OtherInfo = string(otherInfoBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetAutoBan() bool {
|
||||||
|
if channel.AutoBan == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *channel.AutoBan == 1
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Save() error {
|
func (channel *Channel) Save() error {
|
||||||
return DB.Save(channel).Error
|
return DB.Save(channel).Error
|
||||||
}
|
}
|
||||||
|
27
model/log.go
27
model/log.go
@ -7,6 +7,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Log struct {
|
type Log struct {
|
||||||
@ -102,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
tx = DB.Where("type = ?", logType)
|
tx = DB.Where("type = ?", logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name like ?", "%"+modelName+"%")
|
tx = tx.Where("model_name like ?", modelName)
|
||||||
}
|
}
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
@ -137,7 +138,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name like ?", modelName)
|
||||||
}
|
}
|
||||||
if tokenName != "" {
|
if tokenName != "" {
|
||||||
tx = tx.Where("token_name = ?", tokenName)
|
tx = tx.Where("token_name = ?", tokenName)
|
||||||
@ -185,12 +186,18 @@ type Stat struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
||||||
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
tx := DB.Table("logs").Select("sum(quota) quota")
|
||||||
|
|
||||||
|
// 为rpm和tpm创建单独的查询
|
||||||
|
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
||||||
|
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
if tokenName != "" {
|
if tokenName != "" {
|
||||||
tx = tx.Where("token_name = ?", tokenName)
|
tx = tx.Where("token_name = ?", tokenName)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
|
||||||
}
|
}
|
||||||
if startTimestamp != 0 {
|
if startTimestamp != 0 {
|
||||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||||
@ -200,11 +207,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
|
||||||
}
|
}
|
||||||
if channel != 0 {
|
if channel != 0 {
|
||||||
tx = tx.Where("channel_id = ?", channel)
|
tx = tx.Where("channel_id = ?", channel)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
|
|
||||||
|
tx = tx.Where("type = ?", LogTypeConsume)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
|
||||||
|
|
||||||
|
// 只统计最近60秒的rpm和tpm
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
tx.Scan(&stat)
|
||||||
|
rpmTpmQuery.Scan(&stat)
|
||||||
|
|
||||||
return stat
|
return stat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,12 +50,12 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
if token.Status == common.TokenStatusExhausted {
|
if token.Status == common.TokenStatusExhausted {
|
||||||
keyPrefix := key[:3]
|
keyPrefix := key[:3]
|
||||||
keySuffix := key[len(key)-3:]
|
keySuffix := key[len(key)-3:]
|
||||||
return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
|
return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]")
|
||||||
} else if token.Status == common.TokenStatusExpired {
|
} else if token.Status == common.TokenStatusExpired {
|
||||||
return nil, errors.New("该令牌已过期")
|
return token, errors.New("该令牌已过期")
|
||||||
}
|
}
|
||||||
if token.Status != common.TokenStatusEnabled {
|
if token.Status != common.TokenStatusEnabled {
|
||||||
return nil, errors.New("该令牌状态不可用")
|
return token, errors.New("该令牌状态不可用")
|
||||||
}
|
}
|
||||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
@ -65,7 +65,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
common.SysError("failed to update token status" + err.Error())
|
common.SysError("failed to update token status" + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("该令牌已过期")
|
return token, errors.New("该令牌已过期")
|
||||||
}
|
}
|
||||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
@ -78,7 +78,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
|||||||
}
|
}
|
||||||
keyPrefix := key[:3]
|
keyPrefix := key[:3]
|
||||||
keySuffix := key[len(key)-3:]
|
keySuffix := key[len(key)-3:]
|
||||||
return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
|
return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota))
|
||||||
}
|
}
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
}
|
|
||||||
|
|
||||||
// 定义一个映射,存储模型名称和对应的版本
|
|
||||||
var modelVersionMap = map[string]string{
|
|
||||||
"gemini-1.5-pro-latest": "v1beta",
|
|
||||||
"gemini-1.5-flash-latest": "v1beta",
|
|
||||||
"gemini-ultra": "v1beta",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
|
||||||
if !beta {
|
if !beta {
|
||||||
if info.ApiVersion != "" {
|
if info.ApiVersion != "" {
|
||||||
version = info.ApiVersion
|
version = info.ApiVersion
|
||||||
|
@ -33,7 +33,7 @@ type RelayInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
@ -112,7 +112,7 @@ type TaskRelayInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
|
@ -27,6 +27,7 @@ const (
|
|||||||
RelayModeMidjourneyModal
|
RelayModeMidjourneyModal
|
||||||
RelayModeMidjourneyShorten
|
RelayModeMidjourneyShorten
|
||||||
RelayModeSwapFace
|
RelayModeSwapFace
|
||||||
|
RelayModeMidjourneyUpload
|
||||||
|
|
||||||
RelayModeAudioSpeech // tts
|
RelayModeAudioSpeech // tts
|
||||||
RelayModeAudioTranscription // whisper
|
RelayModeAudioTranscription // whisper
|
||||||
@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
|
|||||||
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
||||||
// midjourney plus
|
// midjourney plus
|
||||||
relayMode = RelayModeSwapFace
|
relayMode = RelayModeSwapFace
|
||||||
|
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeMidjourneyUpload
|
||||||
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||||
relayMode = RelayModeMidjourneyImagine
|
relayMode = RelayModeMidjourneyImagine
|
||||||
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||||
|
@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
midjRequest.Action = constant.MjActionShorten
|
midjRequest.Action = constant.MjActionShorten
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionBlend
|
midjRequest.Action = constant.MjActionBlend
|
||||||
|
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||||||
|
midjRequest.Action = constant.MjActionUpload
|
||||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
mjId := ""
|
mjId := ""
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||||
@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("get_channel_null: " + err.Error())
|
common.SysError("get_channel_null: " + err.Error())
|
||||||
}
|
}
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||||
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
responseBody = []byte(newBody)
|
responseBody = []byte(newBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
|
||||||
|
midjourneyTask.Progress = "100%"
|
||||||
|
midjourneyTask.Status = "SUCCESS"
|
||||||
|
}
|
||||||
err = midjourneyTask.Insert()
|
err = midjourneyTask.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||||
responseBody = []byte(newBody)
|
responseBody = []byte(newBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
|
@ -262,13 +262,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if tokenQuota > 100*preConsumedQuota {
|
if tokenQuota > 100*preConsumedQuota {
|
||||||
// 令牌额度充足,信任令牌
|
// 令牌额度充足,信任令牌
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
@ -295,7 +295,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
|
|||||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||||
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||||
modelPrice float64, usePrice bool, extraContent string) {
|
modelPrice float64, usePrice bool, extraContent string) {
|
||||||
|
if usage == nil {
|
||||||
|
usage = &dto.Usage{
|
||||||
|
PromptTokens: relayInfo.PromptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: relayInfo.PromptTokens,
|
||||||
|
}
|
||||||
|
extraContent += " ,(可能是请求出错)"
|
||||||
|
}
|
||||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||||
promptTokens := usage.PromptTokens
|
promptTokens := usage.PromptTokens
|
||||||
completionTokens := usage.CompletionTokens
|
completionTokens := usage.CompletionTokens
|
||||||
|
@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
|||||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
|
|||||||
action = constant.MjActionModal
|
action = constant.MjActionModal
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
action = constant.MjActionSwapFace
|
action = constant.MjActionSwapFace
|
||||||
|
case relayconstant.RelayModeMidjourneyUpload:
|
||||||
|
action = constant.MjActionUpload
|
||||||
case relayconstant.RelayModeMidjourneySimpleChange:
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
}
|
}
|
||||||
var midjResponse dto.MidjourneyResponse
|
var midjResponse dto.MidjourneyResponse
|
||||||
|
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||||
@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||||
}
|
}
|
||||||
respStr := string(responseBody)
|
respStr := string(responseBody)
|
||||||
log.Printf("responseBody: %s", respStr)
|
log.Printf("respStr: %s", respStr)
|
||||||
if respStr == "" {
|
if respStr == "" {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||||
} else {
|
} else {
|
||||||
err = json.Unmarshal(responseBody, &midjResponse)
|
err = json.Unmarshal(responseBody, &midjResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
|
||||||
|
if err2 != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//log.Printf("midjResponse: %v", midjResponse)
|
//log.Printf("midjResponse: %v", midjResponse)
|
||||||
|
@ -2,6 +2,7 @@ import React, { useEffect, useState } from 'react';
|
|||||||
import {
|
import {
|
||||||
API,
|
API,
|
||||||
copy,
|
copy,
|
||||||
|
getTodayStartTimestamp,
|
||||||
isAdmin,
|
isAdmin,
|
||||||
showError,
|
showError,
|
||||||
showSuccess,
|
showSuccess,
|
||||||
@ -412,19 +413,19 @@ const LogsTable = () => {
|
|||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [loadingStat, setLoadingStat] = useState(false);
|
const [loadingStat, setLoadingStat] = useState(false);
|
||||||
const [activePage, setActivePage] = useState(1);
|
const [activePage, setActivePage] = useState(1);
|
||||||
const [logCount, setLogCount] = useState(0);
|
const [logCount, setLogCount] = useState(ITEMS_PER_PAGE);
|
||||||
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE);
|
||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
const [searching, setSearching] = useState(false);
|
const [searching, setSearching] = useState(false);
|
||||||
const [logType, setLogType] = useState(0);
|
const [logType, setLogType] = useState(0);
|
||||||
const isAdminUser = isAdmin();
|
const isAdminUser = isAdmin();
|
||||||
let now = new Date();
|
let now = new Date();
|
||||||
// 初始化start_timestamp为前一天
|
// 初始化start_timestamp为今天0点
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
username: '',
|
username: '',
|
||||||
token_name: '',
|
token_name: '',
|
||||||
model_name: '',
|
model_name: '',
|
||||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
start_timestamp: timestamp2string(getTodayStartTimestamp()),
|
||||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||||
channel: '',
|
channel: '',
|
||||||
});
|
});
|
||||||
@ -449,9 +450,9 @@ const LogsTable = () => {
|
|||||||
const getLogSelfStat = async () => {
|
const getLogSelfStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(
|
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`,
|
url = encodeURI(url);
|
||||||
);
|
let res = await API.get(url);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@ -463,9 +464,9 @@ const LogsTable = () => {
|
|||||||
const getLogStat = async () => {
|
const getLogStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
let res = await API.get(
|
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||||
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`,
|
url = encodeURI(url);
|
||||||
);
|
let res = await API.get(url);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setStat(data);
|
setStat(data);
|
||||||
@ -475,6 +476,9 @@ const LogsTable = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleEyeClick = async () => {
|
const handleEyeClick = async () => {
|
||||||
|
if (loadingStat) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
setLoadingStat(true);
|
setLoadingStat(true);
|
||||||
if (isAdminUser) {
|
if (isAdminUser) {
|
||||||
await getLogStat();
|
await getLogStat();
|
||||||
@ -509,14 +513,14 @@ const LogsTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const setLogsFormat = (logs, total) => {
|
const setLogsFormat = (logs) => {
|
||||||
for (let i = 0; i < logs.length; i++) {
|
for (let i = 0; i < logs.length; i++) {
|
||||||
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
logs[i].timestamp2string = timestamp2string(logs[i].created_at);
|
||||||
logs[i].key = '' + logs[i].id;
|
logs[i].key = '' + logs[i].id;
|
||||||
}
|
}
|
||||||
// data.key = '' + data.id
|
// data.key = '' + data.id
|
||||||
setLogs(logs);
|
setLogs(logs);
|
||||||
setLogCount(total);
|
setLogCount(logs.length + ITEMS_PER_PAGE);
|
||||||
// console.log(logCount);
|
// console.log(logCount);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -531,15 +535,16 @@ const LogsTable = () => {
|
|||||||
} else {
|
} else {
|
||||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
}
|
}
|
||||||
|
url = encodeURI(url);
|
||||||
const res = await API.get(url);
|
const res = await API.get(url);
|
||||||
const { success, message, total, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
if (startIdx === 0) {
|
if (startIdx === 0) {
|
||||||
setLogsFormat(data, total);
|
setLogsFormat(data);
|
||||||
} else {
|
} else {
|
||||||
let newLogs = [...logs];
|
let newLogs = [...logs];
|
||||||
newLogs.splice(startIdx * pageSize, data.length, ...data);
|
newLogs.splice(startIdx * pageSize, data.length, ...data);
|
||||||
setLogsFormat(newLogs, total);
|
setLogsFormat(newLogs);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
@ -574,6 +579,7 @@ const LogsTable = () => {
|
|||||||
const refresh = async () => {
|
const refresh = async () => {
|
||||||
// setLoading(true);
|
// setLoading(true);
|
||||||
setActivePage(1);
|
setActivePage(1);
|
||||||
|
handleEyeClick();
|
||||||
await loadLogs(0, pageSize, logType);
|
await loadLogs(0, pageSize, logType);
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -596,6 +602,7 @@ const LogsTable = () => {
|
|||||||
.catch((reason) => {
|
.catch((reason) => {
|
||||||
showError(reason);
|
showError(reason);
|
||||||
});
|
});
|
||||||
|
handleEyeClick();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const searchLogs = async () => {
|
const searchLogs = async () => {
|
||||||
@ -622,19 +629,17 @@ const LogsTable = () => {
|
|||||||
<Layout>
|
<Layout>
|
||||||
<Header>
|
<Header>
|
||||||
<Spin spinning={loadingStat}>
|
<Spin spinning={loadingStat}>
|
||||||
<h3>
|
<Space>
|
||||||
使用明细(总消耗额度:
|
<Tag color='green' size='large' style={{ padding: 15 }}>
|
||||||
<span
|
总消耗额度: {renderQuota(stat.quota)}
|
||||||
onClick={handleEyeClick}
|
</Tag>
|
||||||
style={{
|
<Tag color='blue' size='large' style={{ padding: 15 }}>
|
||||||
cursor: 'pointer',
|
RPM: {stat.rpm}
|
||||||
color: 'gray',
|
</Tag>
|
||||||
}}
|
<Tag color='purple' size='large' style={{ padding: 15 }}>
|
||||||
>
|
TPM: {stat.tpm}
|
||||||
{showStat ? renderQuota(stat.quota) : '点击查看'}
|
</Tag>
|
||||||
</span>
|
</Space>
|
||||||
)
|
|
||||||
</h3>
|
|
||||||
</Spin>
|
</Spin>
|
||||||
</Header>
|
</Header>
|
||||||
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
||||||
@ -700,20 +705,18 @@ const LogsTable = () => {
|
|||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
<Form.Section>
|
<Button
|
||||||
<Button
|
label='查询'
|
||||||
label='查询'
|
type='primary'
|
||||||
type='primary'
|
htmlType='submit'
|
||||||
htmlType='submit'
|
className='btn-margin-right'
|
||||||
className='btn-margin-right'
|
onClick={refresh}
|
||||||
onClick={() => {
|
loading={loading}
|
||||||
refresh(logType).then();
|
style={{ marginTop: 24 }}
|
||||||
}}
|
>
|
||||||
loading={loading}
|
查询
|
||||||
>
|
</Button>
|
||||||
查询
|
<Form.Section></Form.Section>
|
||||||
</Button>
|
|
||||||
</Form.Section>
|
|
||||||
</>
|
</>
|
||||||
</Form>
|
</Form>
|
||||||
<Table
|
<Table
|
||||||
|
@ -90,6 +90,12 @@ function renderType(type) {
|
|||||||
图混合
|
图混合
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'UPLOAD':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large'>
|
||||||
|
上传文件
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
case 'SHORTEN':
|
case 'SHORTEN':
|
||||||
return (
|
return (
|
||||||
<Tag color='pink' size='large'>
|
<Tag color='pink' size='large'>
|
||||||
|
@ -144,6 +144,12 @@ export function removeTrailingSlash(url) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getTodayStartTimestamp() {
|
||||||
|
var now = new Date();
|
||||||
|
now.setHours(0, 0, 0, 0);
|
||||||
|
return Math.floor(now.getTime() / 1000);
|
||||||
|
}
|
||||||
|
|
||||||
export function timestamp2string(timestamp) {
|
export function timestamp2string(timestamp) {
|
||||||
let date = new Date(timestamp * 1000);
|
let date = new Date(timestamp * 1000);
|
||||||
let year = date.getFullYear().toString();
|
let year = date.getFullYear().toString();
|
||||||
|
Loading…
Reference in New Issue
Block a user