mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-18 11:33:42 +08:00
Compare commits
43 Commits
v0.2.7.4-a
...
v0.2.7.5-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5d0d268c97 | ||
|
|
0b4ef42d86 | ||
|
|
0123ad4d61 | ||
|
|
5acf074541 | ||
|
|
8af0d9f22f | ||
|
|
afd328efcf | ||
|
|
dd12a0052f | ||
|
|
fbe6cd75b1 | ||
|
|
8a9ff36fbf | ||
|
|
88ba8a840e | ||
|
|
e504665f68 | ||
|
|
54657ec27b | ||
|
|
ae6b4e0be2 | ||
|
|
fc0db4505c | ||
|
|
22a98c5879 | ||
|
|
f8f15bd1d0 | ||
|
|
b7690fe17d | ||
|
|
58b4c237a4 | ||
|
|
54f6e660f1 | ||
|
|
3b1745c712 | ||
|
|
c92ab3b569 | ||
|
|
1501ccb919 | ||
|
|
7f2a2a7de0 | ||
|
|
cce7d0258f | ||
|
|
c5e8d7ec20 | ||
|
|
fe16d51fe4 | ||
|
|
2100d8ee0c | ||
|
|
fbce36238e | ||
|
|
a6b6bcfe00 | ||
|
|
07e55cc999 | ||
|
|
b16e6bf423 | ||
|
|
b7bc205b73 | ||
|
|
88cc88c5d0 | ||
|
|
ab1d61d910 | ||
|
|
d4a5df7373 | ||
|
|
9e610c9429 | ||
|
|
da490db6d3 | ||
|
|
b8291dcd13 | ||
|
|
b0d9756c14 | ||
|
|
9dc07a8585 | ||
|
|
caaecb8d54 | ||
|
|
b9454c3f14 | ||
|
|
96bdf97194 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,4 +5,5 @@ upload
|
|||||||
*.db
|
*.db
|
||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
|
web/dist
|
||||||
@@ -64,6 +64,7 @@
|
|||||||
- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。
|
||||||
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。
|
||||||
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。
|
||||||
|
- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 部署要求
|
### 部署要求
|
||||||
|
|||||||
32
common/email-outlook-auth.go
Normal file
32
common/email-outlook-auth.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/smtp"
|
||||||
|
)
|
||||||
|
|
||||||
|
type outlookAuth struct {
|
||||||
|
username, password string
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginAuth(username, password string) smtp.Auth {
|
||||||
|
return &outlookAuth{username, password}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *outlookAuth) Start(_ *smtp.ServerInfo) (string, []byte, error) {
|
||||||
|
return "LOGIN", []byte{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *outlookAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||||
|
if more {
|
||||||
|
switch string(fromServer) {
|
||||||
|
case "Username:":
|
||||||
|
return []byte(a.username), nil
|
||||||
|
case "Password:":
|
||||||
|
return []byte(a.password), nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("unknown fromServer")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
@@ -62,6 +62,9 @@ func SendEmail(subject string, receiver string, content string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
} else if strings.HasSuffix(SMTPAccount, "outlook.com") {
|
||||||
|
auth = LoginAuth(SMTPAccount, SMTPToken)
|
||||||
|
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
|
||||||
} else {
|
} else {
|
||||||
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
|
err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// from songquanpeng/one-api
|
// from songquanpeng/one-api
|
||||||
@@ -180,10 +181,17 @@ var defaultModelPrice = map[string]float64{
|
|||||||
"mj_describe": 0.05,
|
"mj_describe": 0.05,
|
||||||
"mj_upscale": 0.05,
|
"mj_upscale": 0.05,
|
||||||
"swap_face": 0.05,
|
"swap_face": 0.05,
|
||||||
|
"mj_upload": 0.05,
|
||||||
}
|
}
|
||||||
|
|
||||||
var modelPrice map[string]float64 = nil
|
var (
|
||||||
var modelRatio map[string]float64 = nil
|
modelPriceMap = make(map[string]float64)
|
||||||
|
modelPriceMapMutex = sync.RWMutex{}
|
||||||
|
)
|
||||||
|
var (
|
||||||
|
modelRatioMap map[string]float64 = nil
|
||||||
|
modelRatioMapMutex = sync.RWMutex{}
|
||||||
|
)
|
||||||
|
|
||||||
var CompletionRatio map[string]float64 = nil
|
var CompletionRatio map[string]float64 = nil
|
||||||
var defaultCompletionRatio = map[string]float64{
|
var defaultCompletionRatio = map[string]float64{
|
||||||
@@ -191,11 +199,18 @@ var defaultCompletionRatio = map[string]float64{
|
|||||||
"gpt-4-all": 2,
|
"gpt-4-all": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelPrice2JSONString() string {
|
func GetModelPriceMap() map[string]float64 {
|
||||||
if modelPrice == nil {
|
modelPriceMapMutex.Lock()
|
||||||
modelPrice = defaultModelPrice
|
defer modelPriceMapMutex.Unlock()
|
||||||
|
if modelPriceMap == nil {
|
||||||
|
modelPriceMap = defaultModelPrice
|
||||||
}
|
}
|
||||||
jsonBytes, err := json.Marshal(modelPrice)
|
return modelPriceMap
|
||||||
|
}
|
||||||
|
|
||||||
|
func ModelPrice2JSONString() string {
|
||||||
|
GetModelPriceMap()
|
||||||
|
jsonBytes, err := json.Marshal(modelPriceMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model price: " + err.Error())
|
SysError("error marshalling model price: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -203,19 +218,19 @@ func ModelPrice2JSONString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateModelPriceByJSONString(jsonStr string) error {
|
func UpdateModelPriceByJSONString(jsonStr string) error {
|
||||||
modelPrice = make(map[string]float64)
|
modelPriceMapMutex.Lock()
|
||||||
return json.Unmarshal([]byte(jsonStr), &modelPrice)
|
defer modelPriceMapMutex.Unlock()
|
||||||
|
modelPriceMap = make(map[string]float64)
|
||||||
|
return json.Unmarshal([]byte(jsonStr), &modelPriceMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
||||||
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||||
if modelPrice == nil {
|
GetModelPriceMap()
|
||||||
modelPrice = defaultModelPrice
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
price, ok := modelPrice[name]
|
price, ok := modelPriceMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
if printErr {
|
if printErr {
|
||||||
SysError("model price not found: " + name)
|
SysError("model price not found: " + name)
|
||||||
@@ -225,18 +240,18 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|||||||
return price, true
|
return price, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelPriceMap() map[string]float64 {
|
func GetModelRatioMap() map[string]float64 {
|
||||||
if modelPrice == nil {
|
modelRatioMapMutex.Lock()
|
||||||
modelPrice = defaultModelPrice
|
defer modelRatioMapMutex.Unlock()
|
||||||
|
if modelRatioMap == nil {
|
||||||
|
modelRatioMap = defaultModelRatio
|
||||||
}
|
}
|
||||||
return modelPrice
|
return modelRatioMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
if modelRatio == nil {
|
GetModelRatioMap()
|
||||||
modelRatio = defaultModelRatio
|
jsonBytes, err := json.Marshal(modelRatioMap)
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(modelRatio)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
SysError("error marshalling model ratio: " + err.Error())
|
SysError("error marshalling model ratio: " + err.Error())
|
||||||
}
|
}
|
||||||
@@ -244,18 +259,18 @@ func ModelRatio2JSONString() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func UpdateModelRatioByJSONString(jsonStr string) error {
|
func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||||
modelRatio = make(map[string]float64)
|
modelRatioMapMutex.Lock()
|
||||||
return json.Unmarshal([]byte(jsonStr), &modelRatio)
|
defer modelRatioMapMutex.Unlock()
|
||||||
|
modelRatioMap = make(map[string]float64)
|
||||||
|
return json.Unmarshal([]byte(jsonStr), &modelRatioMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelRatio(name string) float64 {
|
func GetModelRatio(name string) float64 {
|
||||||
if modelRatio == nil {
|
GetModelRatioMap()
|
||||||
modelRatio = defaultModelRatio
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
ratio, ok := modelRatio[name]
|
ratio, ok := modelRatioMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
SysError("model ratio not found: " + name)
|
SysError("model ratio not found: " + name)
|
||||||
return 30
|
return 30
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30)
|
||||||
@@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true)
|
|||||||
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true)
|
||||||
|
|
||||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
|
||||||
|
|
||||||
|
var GeminiModelMap = map[string]string{
|
||||||
|
"gemini-1.5-pro-latest": "v1beta",
|
||||||
|
"gemini-1.5-pro-001": "v1beta",
|
||||||
|
"gemini-1.5-pro": "v1beta",
|
||||||
|
"gemini-1.5-pro-exp-0801": "v1beta",
|
||||||
|
"gemini-1.5-flash-latest": "v1beta",
|
||||||
|
"gemini-1.5-flash-001": "v1beta",
|
||||||
|
"gemini-1.5-flash": "v1beta",
|
||||||
|
"gemini-ultra": "v1beta",
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitEnv() {
|
||||||
|
modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP"))
|
||||||
|
if modelVersionMapStr == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, pair := range strings.Split(modelVersionMapStr, ",") {
|
||||||
|
parts := strings.Split(pair, ":")
|
||||||
|
if len(parts) == 2 {
|
||||||
|
GeminiModelMap[parts[0]] = parts[1]
|
||||||
|
} else {
|
||||||
|
common.SysError(fmt.Sprintf("invalid model version map: %s", pair))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ const (
|
|||||||
MjActionLowVariation = "LOW_VARIATION"
|
MjActionLowVariation = "LOW_VARIATION"
|
||||||
MjActionPan = "PAN"
|
MjActionPan = "PAN"
|
||||||
MjActionSwapFace = "SWAP_FACE"
|
MjActionSwapFace = "SWAP_FACE"
|
||||||
|
MjActionUpload = "UPLOAD"
|
||||||
)
|
)
|
||||||
|
|
||||||
var MidjourneyModel2Action = map[string]string{
|
var MidjourneyModel2Action = map[string]string{
|
||||||
@@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{
|
|||||||
"mj_low_variation": MjActionLowVariation,
|
"mj_low_variation": MjActionLowVariation,
|
||||||
"mj_pan": MjActionPan,
|
"mj_pan": MjActionPan,
|
||||||
"swap_face": MjActionSwapFace,
|
"swap_face": MjActionSwapFace,
|
||||||
|
"mj_upload": MjActionUpload,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -240,7 +240,7 @@ func testAllChannels(notify bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// parse *int to bool
|
// parse *int to bool
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
if !channel.GetAutoBan() {
|
||||||
ban = false
|
ban = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package controller
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
@@ -39,43 +40,35 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
func Relay(c *gin.Context) {
|
func Relay(c *gin.Context) {
|
||||||
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
|
||||||
retryTimes := common.RetryTimes
|
|
||||||
requestId := c.GetString(common.RequestIdKey)
|
requestId := c.GetString(common.RequestIdKey)
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
channelType := c.GetInt("channel_type")
|
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
openaiErr := relayHandler(c, relayMode)
|
var openaiErr *dto.OpenAIErrorWithStatusCode
|
||||||
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
|
|
||||||
if openaiErr != nil {
|
for i := 0; i <= common.RetryTimes; i++ {
|
||||||
go processChannelError(c, channelId, channelType, openaiErr)
|
channel, err := getChannel(c, group, originalModel, i)
|
||||||
} else {
|
|
||||||
retryTimes = 0
|
|
||||||
}
|
|
||||||
for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ {
|
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, err.Error())
|
||||||
|
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
|
||||||
c.Set("use_channel", useChannel)
|
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
openaiErr = relayRequest(c, relayMode, channel)
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
|
||||||
openaiErr = relayHandler(c, relayMode)
|
if openaiErr == nil {
|
||||||
if openaiErr != nil {
|
return // 成功处理请求,直接返回
|
||||||
go processChannelError(c, channelId, channel.Type, openaiErr)
|
}
|
||||||
|
|
||||||
|
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
|
||||||
|
|
||||||
|
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
@@ -89,7 +82,42 @@ func Relay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
|
||||||
|
addUsedChannel(c, channel.Id)
|
||||||
|
requestBody, _ := common.GetRequestBody(c)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||||
|
return relayHandler(c, relayMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addUsedChannel(c *gin.Context, channelId int) {
|
||||||
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
|
c.Set("use_channel", useChannel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) {
|
||||||
|
if retryCount == 0 {
|
||||||
|
autoBan := c.GetBool("auto_ban")
|
||||||
|
autoBanInt := 1
|
||||||
|
if !autoBan {
|
||||||
|
autoBanInt = 0
|
||||||
|
}
|
||||||
|
return &model.Channel{
|
||||||
|
Id: c.GetInt("channel_id"),
|
||||||
|
Type: c.GetInt("channel_type"),
|
||||||
|
Name: c.GetString("channel_name"),
|
||||||
|
AutoBan: &autoBanInt,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error()))
|
||||||
|
}
|
||||||
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
return channel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool {
|
||||||
if openaiErr == nil {
|
if openaiErr == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -113,6 +141,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
|
channelType := c.GetInt("channel_type")
|
||||||
|
if channelType == common.ChannelTypeAnthropic {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if openaiErr.StatusCode == 408 {
|
if openaiErr.StatusCode == 408 {
|
||||||
@@ -128,11 +160,11 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func processChannelError(c *gin.Context, channelId int, channelType int, err *dto.OpenAIErrorWithStatusCode) {
|
func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
autoBan := c.GetBool("auto_ban")
|
// 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
// do not use context to get channel info, there may be inconsistent channel info when processing asynchronously
|
||||||
|
common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message))
|
||||||
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
if service.ShouldDisableChannel(channelType, err) && autoBan {
|
||||||
channelName := c.GetString("channel_name")
|
|
||||||
service.DisableChannel(channelId, channelName, err.Error.Message)
|
service.DisableChannel(channelId, channelName, err.Error.Message)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -208,14 +240,14 @@ func RelayTask(c *gin.Context) {
|
|||||||
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ {
|
||||||
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error()))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
|
||||||
c.Set("use_channel", useChannel)
|
c.Set("use_channel", useChannel)
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
requestBody, err := common.GetRequestBody(c)
|
requestBody, err := common.GetRequestBody(c)
|
||||||
@@ -225,7 +257,7 @@ func RelayTask(c *gin.Context) {
|
|||||||
useChannel := c.GetStringSlice("use_channel")
|
useChannel := c.GetStringSlice("use_channel")
|
||||||
if len(useChannel) > 1 {
|
if len(useChannel) > 1 {
|
||||||
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
common.LogInfo(c, retryLogStr)
|
||||||
}
|
}
|
||||||
if taskErr != nil {
|
if taskErr != nil {
|
||||||
if taskErr.StatusCode == http.StatusTooManyRequests {
|
if taskErr.StatusCode == http.StatusTooManyRequests {
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ func RequestEpay(c *gin.Context) {
|
|||||||
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
returnUrl, _ := url.Parse(constant.ServerAddress + "/log")
|
||||||
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
|
||||||
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
|
||||||
|
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)
|
||||||
client := GetEpayClient()
|
client := GetEpayClient()
|
||||||
if client == nil {
|
if client == nil {
|
||||||
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"})
|
||||||
@@ -101,8 +102,8 @@ func RequestEpay(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||||
Type: payType,
|
Type: payType,
|
||||||
ServiceTradeNo: "A" + tradeNo,
|
ServiceTradeNo: tradeNo,
|
||||||
Name: "B" + tradeNo,
|
Name: fmt.Sprintf("TUC%d", req.Amount),
|
||||||
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
Money: strconv.FormatFloat(payMoney, 'f', 2, 64),
|
||||||
Device: epay.PC,
|
Device: epay.PC,
|
||||||
NotifyUrl: notifyUrl,
|
NotifyUrl: notifyUrl,
|
||||||
@@ -120,7 +121,7 @@ func RequestEpay(c *gin.Context) {
|
|||||||
UserId: id,
|
UserId: id,
|
||||||
Amount: amount,
|
Amount: amount,
|
||||||
Money: payMoney,
|
Money: payMoney,
|
||||||
TradeNo: "A" + tradeNo,
|
TradeNo: tradeNo,
|
||||||
CreateTime: time.Now().Unix(),
|
CreateTime: time.Now().Unix(),
|
||||||
Status: "pending",
|
Status: "pending",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -791,11 +791,11 @@ type topUpRequest struct {
|
|||||||
Key string `json:"key"`
|
Key string `json:"key"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var lock = sync.Mutex{}
|
var topUpLock = sync.Mutex{}
|
||||||
|
|
||||||
func TopUp(c *gin.Context) {
|
func TopUp(c *gin.Context) {
|
||||||
lock.Lock()
|
topUpLock.Lock()
|
||||||
defer lock.Unlock()
|
defer topUpLock.Unlock()
|
||||||
req := topUpRequest{}
|
req := topUpRequest{}
|
||||||
err := c.ShouldBindJSON(&req)
|
err := c.ShouldBindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -33,6 +33,12 @@ type MidjourneyResponse struct {
|
|||||||
Result string `json:"result"`
|
Result string `json:"result"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MidjourneyUploadResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Result []string `json:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
type MidjourneyResponseWithStatusCode struct {
|
type MidjourneyResponseWithStatusCode struct {
|
||||||
StatusCode int `json:"statusCode"`
|
StatusCode int `json:"statusCode"`
|
||||||
Response MidjourneyResponse
|
Response MidjourneyResponse
|
||||||
|
|||||||
2
main.go
2
main.go
@@ -55,6 +55,8 @@ func main() {
|
|||||||
common.FatalLog("failed to initialize Redis: " + err.Error())
|
common.FatalLog("failed to initialize Redis: " + err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize constants
|
||||||
|
constant.InitEnv()
|
||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
|
|||||||
@@ -184,19 +184,13 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
if channel == nil {
|
if channel == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Set("channel", channel.Type)
|
|
||||||
c.Set("channel_id", channel.Id)
|
c.Set("channel_id", channel.Id)
|
||||||
c.Set("channel_name", channel.Name)
|
c.Set("channel_name", channel.Name)
|
||||||
c.Set("channel_type", channel.Type)
|
c.Set("channel_type", channel.Type)
|
||||||
ban := true
|
|
||||||
// parse *int to bool
|
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
|
||||||
ban = false
|
|
||||||
}
|
|
||||||
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization {
|
||||||
c.Set("channel_organization", *channel.OpenAIOrganization)
|
c.Set("channel_organization", *channel.OpenAIOrganization)
|
||||||
}
|
}
|
||||||
c.Set("auto_ban", ban)
|
c.Set("auto_ban", channel.GetAutoBan())
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
||||||
|
userId := c.GetInt("id")
|
||||||
c.JSON(statusCode, gin.H{
|
c.JSON(statusCode, gin.H{
|
||||||
"error": gin.H{
|
"error": gin.H{
|
||||||
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
"message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)),
|
||||||
@@ -13,7 +15,7 @@ func abortWithOpenAiMessage(c *gin.Context, statusCode int, message string) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
common.LogError(c.Request.Context(), message)
|
common.LogError(c.Request.Context(), fmt.Sprintf("user %d | %s", userId, message))
|
||||||
}
|
}
|
||||||
|
|
||||||
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
func abortWithMidjourneyMessage(c *gin.Context, statusCode int, code int, description string) {
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) {
|
|||||||
channel.OtherInfo = string(otherInfoBytes)
|
channel.OtherInfo = string(otherInfoBytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetAutoBan() bool {
|
||||||
|
if channel.AutoBan == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *channel.AutoBan == 1
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Save() error {
|
func (channel *Channel) Save() error {
|
||||||
return DB.Save(channel).Error
|
return DB.Save(channel).Error
|
||||||
}
|
}
|
||||||
@@ -100,8 +107,8 @@ func SearchChannels(keyword string, group string, model string) ([]*Channel, err
|
|||||||
var whereClause string
|
var whereClause string
|
||||||
var args []interface{}
|
var args []interface{}
|
||||||
if group != "" {
|
if group != "" {
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?"
|
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " = ? AND " + modelsCol + " LIKE ?"
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, group, "%"+model+"%")
|
||||||
} else {
|
} else {
|
||||||
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
|
whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?"
|
||||||
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
|
args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%")
|
||||||
|
|||||||
27
model/log.go
27
model/log.go
@@ -7,6 +7,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Log struct {
|
type Log struct {
|
||||||
@@ -102,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
tx = DB.Where("type = ?", logType)
|
tx = DB.Where("type = ?", logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name like ?", "%"+modelName+"%")
|
tx = tx.Where("model_name like ?", modelName)
|
||||||
}
|
}
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
@@ -131,7 +132,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
|
|||||||
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
tx = DB.Where("user_id = ? and type = ?", userId, logType)
|
||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name like ?", modelName)
|
||||||
}
|
}
|
||||||
if tokenName != "" {
|
if tokenName != "" {
|
||||||
tx = tx.Where("token_name = ?", tokenName)
|
tx = tx.Where("token_name = ?", tokenName)
|
||||||
@@ -172,12 +173,18 @@ type Stat struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) {
|
||||||
tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
tx := DB.Table("logs").Select("sum(quota) quota")
|
||||||
|
|
||||||
|
// 为rpm和tpm创建单独的查询
|
||||||
|
rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm")
|
||||||
|
|
||||||
if username != "" {
|
if username != "" {
|
||||||
tx = tx.Where("username = ?", username)
|
tx = tx.Where("username = ?", username)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("username = ?", username)
|
||||||
}
|
}
|
||||||
if tokenName != "" {
|
if tokenName != "" {
|
||||||
tx = tx.Where("token_name = ?", tokenName)
|
tx = tx.Where("token_name = ?", tokenName)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName)
|
||||||
}
|
}
|
||||||
if startTimestamp != 0 {
|
if startTimestamp != 0 {
|
||||||
tx = tx.Where("created_at >= ?", startTimestamp)
|
tx = tx.Where("created_at >= ?", startTimestamp)
|
||||||
@@ -187,11 +194,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
}
|
}
|
||||||
if modelName != "" {
|
if modelName != "" {
|
||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName)
|
||||||
}
|
}
|
||||||
if channel != 0 {
|
if channel != 0 {
|
||||||
tx = tx.Where("channel_id = ?", channel)
|
tx = tx.Where("channel_id = ?", channel)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
tx.Where("type = ?", LogTypeConsume).Scan(&stat)
|
|
||||||
|
tx = tx.Where("type = ?", LogTypeConsume)
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume)
|
||||||
|
|
||||||
|
// 只统计最近60秒的rpm和tpm
|
||||||
|
rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix())
|
||||||
|
|
||||||
|
// 执行查询
|
||||||
|
tx.Scan(&stat)
|
||||||
|
rpmTpmQuery.Scan(&stat)
|
||||||
|
|
||||||
return stat
|
return stat
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -222,9 +222,11 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
service.Done(c)
|
service.Done(c)
|
||||||
err = resp.Body.Close()
|
if resp != nil {
|
||||||
if err != nil {
|
err = resp.Body.Close()
|
||||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil, &usage
|
return nil, &usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package cloudflare
|
package cloudflare
|
||||||
|
|
||||||
var ModelList = []string{
|
var ModelList = []string{
|
||||||
|
"@cf/meta/llama-3.1-8b-instruct",
|
||||||
"@cf/meta/llama-2-7b-chat-fp16",
|
"@cf/meta/llama-2-7b-chat-fp16",
|
||||||
"@cf/meta/llama-2-7b-chat-int8",
|
"@cf/meta/llama-2-7b-chat-int8",
|
||||||
"@cf/mistral/mistral-7b-instruct-v0.1",
|
"@cf/mistral/mistral-7b-instruct-v0.1",
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
|
|||||||
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n")
|
||||||
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
} else if constant.DifyDebug && difyResponse.Event == "node_started" {
|
||||||
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n")
|
||||||
} else if difyResponse.Event == "message" {
|
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
|
||||||
choice.Delta.SetContentString(difyResponse.Answer)
|
choice.Delta.SetContentString(difyResponse.Answer)
|
||||||
}
|
}
|
||||||
response.Choices = append(response.Choices, choice)
|
response.Choices = append(response.Choices, choice)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
@@ -25,18 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||||
}
|
|
||||||
|
|
||||||
// 定义一个映射,存储模型名称和对应的版本
|
|
||||||
var modelVersionMap = map[string]string{
|
|
||||||
"gemini-1.5-pro-latest": "v1beta",
|
|
||||||
"gemini-1.5-flash-latest": "v1beta",
|
|
||||||
"gemini-ultra": "v1beta",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
// 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1"
|
||||||
version, beta := modelVersionMap[info.UpstreamModelName]
|
version, beta := constant.GeminiModelMap[info.UpstreamModelName]
|
||||||
if !beta {
|
if !beta {
|
||||||
if info.ApiVersion != "" {
|
if info.ApiVersion != "" {
|
||||||
version = info.ApiVersion
|
version = info.ApiVersion
|
||||||
|
|||||||
@@ -83,13 +83,28 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
|
|||||||
if imageNum > GeminiVisionMaxImageNum {
|
if imageNum > GeminiVisionMaxImageNum {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
// 判断是否是url
|
||||||
parts = append(parts, GeminiPart{
|
if strings.HasPrefix(part.ImageUrl.(dto.MessageImageUrl).Url, "http") {
|
||||||
InlineData: &GeminiInlineData{
|
// 是url,获取图片的类型和base64编码的数据
|
||||||
MimeType: mimeType,
|
mimeType, data, _ := service.GetImageFromUrl(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
Data: data,
|
parts = append(parts, GeminiPart{
|
||||||
},
|
InlineData: &GeminiInlineData{
|
||||||
})
|
MimeType: mimeType,
|
||||||
|
Data: data,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
_, format, base64String, err := service.DecodeBase64ImageData(part.ImageUrl.(dto.MessageImageUrl).Url)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts = append(parts, GeminiPart{
|
||||||
|
InlineData: &GeminiInlineData{
|
||||||
|
MimeType: "image/" + format,
|
||||||
|
Data: base64String,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
content.Parts = parts
|
content.Parts = parts
|
||||||
|
|||||||
@@ -3,14 +3,18 @@ package ollama
|
|||||||
import "one-api/dto"
|
import "one-api/dto"
|
||||||
|
|
||||||
type OllamaRequest struct {
|
type OllamaRequest struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []dto.Message `json:"messages,omitempty"`
|
Messages []dto.Message `json:"messages,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
Seed float64 `json:"seed,omitempty"`
|
Seed float64 `json:"seed,omitempty"`
|
||||||
Topp float64 `json:"top_p,omitempty"`
|
Topp float64 `json:"top_p,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
Stop any `json:"stop,omitempty"`
|
Stop any `json:"stop,omitempty"`
|
||||||
|
Tools []dto.ToolCall `json:"tools,omitempty"`
|
||||||
|
ResponseFormat *dto.ResponseFormat `json:"response_format,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OllamaEmbeddingRequest struct {
|
type OllamaEmbeddingRequest struct {
|
||||||
@@ -21,6 +25,3 @@ type OllamaEmbeddingRequest struct {
|
|||||||
type OllamaEmbeddingResponse struct {
|
type OllamaEmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding,omitempty"`
|
Embedding []float64 `json:"embedding,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
//type OllamaOptions struct {
|
|
||||||
//}
|
|
||||||
|
|||||||
@@ -28,14 +28,18 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
|||||||
Stop, _ = request.Stop.([]string)
|
Stop, _ = request.Stop.([]string)
|
||||||
}
|
}
|
||||||
return &OllamaRequest{
|
return &OllamaRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
Stream: request.Stream,
|
Stream: request.Stream,
|
||||||
Temperature: request.Temperature,
|
Temperature: request.Temperature,
|
||||||
Seed: request.Seed,
|
Seed: request.Seed,
|
||||||
Topp: request.TopP,
|
Topp: request.TopP,
|
||||||
TopK: request.TopK,
|
TopK: request.TopK,
|
||||||
Stop: Stop,
|
Stop: Stop,
|
||||||
|
Tools: request.Tools,
|
||||||
|
ResponseFormat: request.ResponseFormat,
|
||||||
|
FrequencyPenalty: request.FrequencyPenalty,
|
||||||
|
PresencePenalty: request.PresencePenalty,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
|
|
||||||
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
containStreamUsage := false
|
containStreamUsage := false
|
||||||
responseId := ""
|
var responseId string
|
||||||
var createAt int64 = 0
|
var createAt int64 = 0
|
||||||
var systemFingerprint string
|
var systemFingerprint string
|
||||||
model := info.UpstreamModelName
|
model := info.UpstreamModelName
|
||||||
@@ -86,7 +86,13 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
var lastStreamResponse dto.ChatCompletionsStreamResponse
|
||||||
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if lastStreamResponse.Usage != nil && service.ValidUsage(lastStreamResponse.Usage) {
|
responseId = lastStreamResponse.Id
|
||||||
|
createAt = lastStreamResponse.Created
|
||||||
|
systemFingerprint = lastStreamResponse.GetSystemFingerprint()
|
||||||
|
model = lastStreamResponse.Model
|
||||||
|
if service.ValidUsage(lastStreamResponse.Usage) {
|
||||||
|
containStreamUsage = true
|
||||||
|
usage = lastStreamResponse.Usage
|
||||||
if !info.ShouldIncludeUsage {
|
if !info.ShouldIncludeUsage {
|
||||||
shouldSendLastResp = false
|
shouldSendLastResp = false
|
||||||
}
|
}
|
||||||
@@ -109,14 +115,9 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
var streamResponse dto.ChatCompletionsStreamResponse
|
var streamResponse dto.ChatCompletionsStreamResponse
|
||||||
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
responseId = streamResponse.Id
|
//if service.ValidUsage(streamResponse.Usage) {
|
||||||
createAt = streamResponse.Created
|
// usage = streamResponse.Usage
|
||||||
systemFingerprint = streamResponse.GetSystemFingerprint()
|
//}
|
||||||
model = streamResponse.Model
|
|
||||||
if service.ValidUsage(streamResponse.Usage) {
|
|
||||||
usage = streamResponse.Usage
|
|
||||||
containStreamUsage = true
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
if choice.Delta.ToolCalls != nil {
|
if choice.Delta.ToolCalls != nil {
|
||||||
@@ -133,14 +134,10 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for _, streamResponse := range streamResponses {
|
for _, streamResponse := range streamResponses {
|
||||||
responseId = streamResponse.Id
|
//if service.ValidUsage(streamResponse.Usage) {
|
||||||
createAt = streamResponse.Created
|
// usage = streamResponse.Usage
|
||||||
systemFingerprint = streamResponse.GetSystemFingerprint()
|
// containStreamUsage = true
|
||||||
model = streamResponse.Model
|
//}
|
||||||
if service.ValidUsage(streamResponse.Usage) {
|
|
||||||
usage = streamResponse.Usage
|
|
||||||
containStreamUsage = true
|
|
||||||
}
|
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
responseTextBuilder.WriteString(choice.Delta.GetContentString())
|
||||||
if choice.Delta.ToolCalls != nil {
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ type RelayInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
@@ -112,7 +112,7 @@ type TaskRelayInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel_type")
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ const (
|
|||||||
RelayModeMidjourneyModal
|
RelayModeMidjourneyModal
|
||||||
RelayModeMidjourneyShorten
|
RelayModeMidjourneyShorten
|
||||||
RelayModeSwapFace
|
RelayModeSwapFace
|
||||||
|
RelayModeMidjourneyUpload
|
||||||
|
|
||||||
RelayModeAudioSpeech // tts
|
RelayModeAudioSpeech // tts
|
||||||
RelayModeAudioTranscription // whisper
|
RelayModeAudioTranscription // whisper
|
||||||
@@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int {
|
|||||||
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
} else if strings.HasSuffix(path, "/mj/insight-face/swap") {
|
||||||
// midjourney plus
|
// midjourney plus
|
||||||
relayMode = RelayModeSwapFace
|
relayMode = RelayModeSwapFace
|
||||||
|
} else if strings.HasSuffix(path, "/submit/upload-discord-images") {
|
||||||
|
// midjourney plus
|
||||||
|
relayMode = RelayModeMidjourneyUpload
|
||||||
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
} else if strings.HasSuffix(path, "/mj/submit/imagine") {
|
||||||
relayMode = RelayModeMidjourneyImagine
|
relayMode = RelayModeMidjourneyImagine
|
||||||
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
} else if strings.HasSuffix(path, "/mj/submit/blend") {
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
quota := int(modelPrice*groupRatio*common.QuotaPerUnit*sizeRatio*qualityRatio) * imageRequest.N
|
imageRatio := modelPrice * sizeRatio * qualityRatio * float64(imageRequest.N)
|
||||||
|
quota := int(imageRatio * groupRatio * common.QuotaPerUnit)
|
||||||
|
|
||||||
if userQuota-quota < 0 {
|
if userQuota-quota < 0 {
|
||||||
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
@@ -180,7 +181,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
|
|
||||||
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality)
|
||||||
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent)
|
postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, imageRatio, true, logContent)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
midjRequest.Action = constant.MjActionShorten
|
midjRequest.Action = constant.MjActionShorten
|
||||||
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||||||
midjRequest.Action = constant.MjActionBlend
|
midjRequest.Action = constant.MjActionBlend
|
||||||
|
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||||||
|
midjRequest.Action = constant.MjActionUpload
|
||||||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||||||
mjId := ""
|
mjId := ""
|
||||||
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
if relayMode == relayconstant.RelayModeMidjourneyChange {
|
||||||
@@ -547,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("get_channel_null: " + err.Error())
|
common.SysError("get_channel_null: " + err.Error())
|
||||||
}
|
}
|
||||||
if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled {
|
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||||||
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
responseBody = []byte(newBody)
|
responseBody = []byte(newBody)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
|
||||||
|
midjourneyTask.Progress = "100%"
|
||||||
|
midjourneyTask.Status = "SUCCESS"
|
||||||
|
}
|
||||||
err = midjourneyTask.Insert()
|
err = midjourneyTask.Insert()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &dto.MidjourneyResponse{
|
return &dto.MidjourneyResponse{
|
||||||
@@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
|
|||||||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||||||
responseBody = []byte(newBody)
|
responseBody = []byte(newBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
|
|
||||||
|
|||||||
@@ -253,13 +253,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo
|
|||||||
if tokenQuota > 100*preConsumedQuota {
|
if tokenQuota > 100*preConsumedQuota {
|
||||||
// 令牌额度充足,信任令牌
|
// 令牌额度充足,信任令牌
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// in this case, we do not pre-consume quota
|
// in this case, we do not pre-consume quota
|
||||||
// because the user has enough quota
|
// because the user has enough quota
|
||||||
preConsumedQuota = 0
|
preConsumedQuota = 0
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if preConsumedQuota > 0 {
|
if preConsumedQuota > 0 {
|
||||||
@@ -286,7 +286,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
|
|||||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||||
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
|
||||||
modelPrice float64, usePrice bool, extraContent string) {
|
modelPrice float64, usePrice bool, extraContent string) {
|
||||||
|
if usage == nil {
|
||||||
|
usage = &dto.Usage{
|
||||||
|
PromptTokens: relayInfo.PromptTokens,
|
||||||
|
CompletionTokens: 0,
|
||||||
|
TotalTokens: relayInfo.PromptTokens,
|
||||||
|
}
|
||||||
|
extraContent += " ,(可能是请求出错)"
|
||||||
|
}
|
||||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||||
promptTokens := usage.PromptTokens
|
promptTokens := usage.PromptTokens
|
||||||
completionTokens := usage.CompletionTokens
|
completionTokens := usage.CompletionTokens
|
||||||
|
|||||||
@@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
|||||||
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney)
|
||||||
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney)
|
||||||
|
relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin
|
|||||||
action = constant.MjActionModal
|
action = constant.MjActionModal
|
||||||
case relayconstant.RelayModeSwapFace:
|
case relayconstant.RelayModeSwapFace:
|
||||||
action = constant.MjActionSwapFace
|
action = constant.MjActionSwapFace
|
||||||
|
case relayconstant.RelayModeMidjourneyUpload:
|
||||||
|
action = constant.MjActionUpload
|
||||||
case relayconstant.RelayModeMidjourneySimpleChange:
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
||||||
params := ConvertSimpleChangeParams(midjRequest.Content)
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
||||||
if params == nil {
|
if params == nil {
|
||||||
@@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
||||||
}
|
}
|
||||||
var midjResponse dto.MidjourneyResponse
|
var midjResponse dto.MidjourneyResponse
|
||||||
|
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||||
@@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
|||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err
|
||||||
}
|
}
|
||||||
respStr := string(responseBody)
|
respStr := string(responseBody)
|
||||||
log.Printf("responseBody: %s", respStr)
|
log.Printf("respStr: %s", respStr)
|
||||||
if respStr == "" {
|
if respStr == "" {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
||||||
} else {
|
} else {
|
||||||
err = json.Unmarshal(responseBody, &midjResponse)
|
err = json.Unmarshal(responseBody, &midjResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
|
||||||
|
if err2 != nil {
|
||||||
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
//log.Printf("midjResponse: %v", midjResponse)
|
//log.Printf("midjResponse: %v", midjResponse)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import React, { useContext, useEffect, useState } from 'react';
|
import React, { useContext, useEffect, useState } from 'react';
|
||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showInfo, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showInfo, showSuccess, updateAPI } from '../helpers';
|
||||||
import { onGitHubOAuthClicked } from './utils';
|
import { onGitHubOAuthClicked } from './utils';
|
||||||
import Turnstile from 'react-turnstile';
|
import Turnstile from 'react-turnstile';
|
||||||
import {
|
import {
|
||||||
@@ -101,6 +101,7 @@ const LoginForm = () => {
|
|||||||
if (success) {
|
if (success) {
|
||||||
userDispatch({ type: 'login', payload: data });
|
userDispatch({ type: 'login', payload: data });
|
||||||
setUserData(data);
|
setUserData(data);
|
||||||
|
updateAPI()
|
||||||
showSuccess('登录成功!');
|
showSuccess('登录成功!');
|
||||||
if (username === 'root' && password === '123456') {
|
if (username === 'root' && password === '123456') {
|
||||||
Modal.error({
|
Modal.error({
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import {
|
import {
|
||||||
API,
|
API,
|
||||||
copy,
|
copy, getTodayStartTimestamp,
|
||||||
isAdmin,
|
isAdmin,
|
||||||
showError,
|
showError,
|
||||||
showSuccess,
|
showSuccess,
|
||||||
timestamp2string,
|
timestamp2string
|
||||||
} from '../helpers';
|
} from '../helpers';
|
||||||
|
|
||||||
import {
|
import {
|
||||||
@@ -419,12 +419,12 @@ const LogsTable = () => {
|
|||||||
const [logType, setLogType] = useState(0);
|
const [logType, setLogType] = useState(0);
|
||||||
const isAdminUser = isAdmin();
|
const isAdminUser = isAdmin();
|
||||||
let now = new Date();
|
let now = new Date();
|
||||||
// 初始化start_timestamp为前一天
|
// 初始化start_timestamp为今天0点
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
username: '',
|
username: '',
|
||||||
token_name: '',
|
token_name: '',
|
||||||
model_name: '',
|
model_name: '',
|
||||||
start_timestamp: timestamp2string(now.getTime() / 1000 - 86400),
|
start_timestamp: timestamp2string(getTodayStartTimestamp()),
|
||||||
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
end_timestamp: timestamp2string(now.getTime() / 1000 + 3600),
|
||||||
channel: '',
|
channel: '',
|
||||||
});
|
});
|
||||||
@@ -449,8 +449,10 @@ const LogsTable = () => {
|
|||||||
const getLogSelfStat = async () => {
|
const getLogSelfStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
|
let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
|
url = encodeURI(url);
|
||||||
let res = await API.get(
|
let res = await API.get(
|
||||||
`/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`,
|
url,
|
||||||
);
|
);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -463,8 +465,10 @@ const LogsTable = () => {
|
|||||||
const getLogStat = async () => {
|
const getLogStat = async () => {
|
||||||
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
let localStartTimestamp = Date.parse(start_timestamp) / 1000;
|
||||||
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
let localEndTimestamp = Date.parse(end_timestamp) / 1000;
|
||||||
|
let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`;
|
||||||
|
url = encodeURI(url);
|
||||||
let res = await API.get(
|
let res = await API.get(
|
||||||
`/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`,
|
url,
|
||||||
);
|
);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -475,6 +479,9 @@ const LogsTable = () => {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const handleEyeClick = async () => {
|
const handleEyeClick = async () => {
|
||||||
|
if (loadingStat) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
setLoadingStat(true);
|
setLoadingStat(true);
|
||||||
if (isAdminUser) {
|
if (isAdminUser) {
|
||||||
await getLogStat();
|
await getLogStat();
|
||||||
@@ -531,6 +538,7 @@ const LogsTable = () => {
|
|||||||
} else {
|
} else {
|
||||||
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`;
|
||||||
}
|
}
|
||||||
|
url = encodeURI(url);
|
||||||
const res = await API.get(url);
|
const res = await API.get(url);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -574,6 +582,7 @@ const LogsTable = () => {
|
|||||||
const refresh = async () => {
|
const refresh = async () => {
|
||||||
// setLoading(true);
|
// setLoading(true);
|
||||||
setActivePage(1);
|
setActivePage(1);
|
||||||
|
handleEyeClick();
|
||||||
await loadLogs(0, pageSize, logType);
|
await loadLogs(0, pageSize, logType);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -596,6 +605,7 @@ const LogsTable = () => {
|
|||||||
.catch((reason) => {
|
.catch((reason) => {
|
||||||
showError(reason);
|
showError(reason);
|
||||||
});
|
});
|
||||||
|
handleEyeClick();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const searchLogs = async () => {
|
const searchLogs = async () => {
|
||||||
@@ -622,19 +632,17 @@ const LogsTable = () => {
|
|||||||
<Layout>
|
<Layout>
|
||||||
<Header>
|
<Header>
|
||||||
<Spin spinning={loadingStat}>
|
<Spin spinning={loadingStat}>
|
||||||
<h3>
|
<Space>
|
||||||
使用明细(总消耗额度:
|
<Tag color='green' size='large' style={{ padding: 15 }}>
|
||||||
<span
|
总消耗额度: {renderQuota(stat.quota)}
|
||||||
onClick={handleEyeClick}
|
</Tag>
|
||||||
style={{
|
<Tag color='blue' size='large' style={{ padding: 15 }}>
|
||||||
cursor: 'pointer',
|
RPM: {stat.rpm}
|
||||||
color: 'gray',
|
</Tag>
|
||||||
}}
|
<Tag color='purple' size='large' style={{ padding: 15 }}>
|
||||||
>
|
TPM: {stat.tpm}
|
||||||
{showStat ? renderQuota(stat.quota) : '点击查看'}
|
</Tag>
|
||||||
</span>
|
</Space>
|
||||||
)
|
|
||||||
</h3>
|
|
||||||
</Spin>
|
</Spin>
|
||||||
</Header>
|
</Header>
|
||||||
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
<Form layout='horizontal' style={{ marginTop: 10 }}>
|
||||||
@@ -700,17 +708,19 @@ const LogsTable = () => {
|
|||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
<Button
|
||||||
|
label='查询'
|
||||||
|
type='primary'
|
||||||
|
htmlType='submit'
|
||||||
|
className='btn-margin-right'
|
||||||
|
onClick={refresh}
|
||||||
|
loading={loading}
|
||||||
|
style={{ marginTop: 24 }}
|
||||||
|
>
|
||||||
|
查询
|
||||||
|
</Button>
|
||||||
<Form.Section>
|
<Form.Section>
|
||||||
<Button
|
|
||||||
label='查询'
|
|
||||||
type='primary'
|
|
||||||
htmlType='submit'
|
|
||||||
className='btn-margin-right'
|
|
||||||
onClick={refresh}
|
|
||||||
loading={loading}
|
|
||||||
>
|
|
||||||
查询
|
|
||||||
</Button>
|
|
||||||
</Form.Section>
|
</Form.Section>
|
||||||
</>
|
</>
|
||||||
</Form>
|
</Form>
|
||||||
|
|||||||
@@ -90,6 +90,12 @@ function renderType(type) {
|
|||||||
图混合
|
图混合
|
||||||
</Tag>
|
</Tag>
|
||||||
);
|
);
|
||||||
|
case 'UPLOAD':
|
||||||
|
return (
|
||||||
|
<Tag color='blue' size='large'>
|
||||||
|
上传文件
|
||||||
|
</Tag>
|
||||||
|
);
|
||||||
case 'SHORTEN':
|
case 'SHORTEN':
|
||||||
return (
|
return (
|
||||||
<Tag color='pink' size='large'>
|
<Tag color='pink' size='large'>
|
||||||
@@ -239,7 +245,7 @@ const renderTimestamp = (timestampInSeconds) => {
|
|||||||
// 修改renderDuration函数以包含颜色逻辑
|
// 修改renderDuration函数以包含颜色逻辑
|
||||||
function renderDuration(submit_time, finishTime) {
|
function renderDuration(submit_time, finishTime) {
|
||||||
// 确保startTime和finishTime都是有效的时间戳
|
// 确保startTime和finishTime都是有效的时间戳
|
||||||
if (!submit_time || !finishTime) return 'N/A';
|
if (!submit_time || !finishTime) return 'N/A';
|
||||||
|
|
||||||
// 将时间戳转换为Date对象
|
// 将时间戳转换为Date对象
|
||||||
const start = new Date(submit_time);
|
const start = new Date(submit_time);
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import { getUserIdFromLocalStorage, showError } from './utils';
|
import { getUserIdFromLocalStorage, showError } from './utils';
|
||||||
import axios from 'axios';
|
import axios from 'axios';
|
||||||
|
|
||||||
export const API = axios.create({
|
export let API = axios.create({
|
||||||
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
|
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
|
||||||
? import.meta.env.VITE_REACT_APP_SERVER_URL
|
? import.meta.env.VITE_REACT_APP_SERVER_URL
|
||||||
: '',
|
: '',
|
||||||
@@ -10,6 +10,17 @@ export const API = axios.create({
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export function updateAPI() {
|
||||||
|
API = axios.create({
|
||||||
|
baseURL: import.meta.env.VITE_REACT_APP_SERVER_URL
|
||||||
|
? import.meta.env.VITE_REACT_APP_SERVER_URL
|
||||||
|
: '',
|
||||||
|
headers: {
|
||||||
|
'New-API-User': getUserIdFromLocalStorage()
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
API.interceptors.response.use(
|
API.interceptors.response.use(
|
||||||
(response) => response,
|
(response) => response,
|
||||||
(error) => {
|
(error) => {
|
||||||
|
|||||||
@@ -140,6 +140,12 @@ export function removeTrailingSlash(url) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getTodayStartTimestamp() {
|
||||||
|
var now = new Date();
|
||||||
|
now.setHours(0, 0, 0, 0);
|
||||||
|
return Math.floor(now.getTime() / 1000);
|
||||||
|
}
|
||||||
|
|
||||||
export function timestamp2string(timestamp) {
|
export function timestamp2string(timestamp) {
|
||||||
let date = new Date(timestamp * 1000);
|
let date = new Date(timestamp * 1000);
|
||||||
let year = date.getFullYear().toString();
|
let year = date.getFullYear().toString();
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import TokensTable from '../../components/TokensTable';
|
import TokensTable from '../../components/TokensTable';
|
||||||
import { Layout } from '@douyinfe/semi-ui';
|
import { Banner, Layout } from '@douyinfe/semi-ui';
|
||||||
const Token = () => (
|
const Token = () => (
|
||||||
<>
|
<>
|
||||||
<Layout>
|
<Layout>
|
||||||
<Layout.Header>
|
<Layout.Header>
|
||||||
<h3>我的令牌</h3>
|
<Banner
|
||||||
|
type='warning'
|
||||||
|
description='令牌无法精确控制使用额度,请勿直接将令牌分发给用户。'
|
||||||
|
/>
|
||||||
</Layout.Header>
|
</Layout.Header>
|
||||||
<Layout.Content>
|
<Layout.Content>
|
||||||
<TokensTable />
|
<TokensTable />
|
||||||
|
|||||||
Reference in New Issue
Block a user