From 2100d8ee0caf3b26ae887a3aed3ec6f5c9fcbb71 Mon Sep 17 00:00:00 2001 From: kakingone Date: Wed, 31 Jul 2024 15:48:51 +0800 Subject: [PATCH 01/21] addupload --- common/model-ratio.go | 1 + constant/midjourney.go | 2 ++ dto/midjourney.go | 6 ++++++ relay/constant/relay_mode.go | 4 ++++ relay/relay-mj.go | 8 ++++++-- router/relay-router.go | 1 + service/midjourney.go | 11 ++++++++--- web/src/components/MjLogsTable.js | 8 +++++++- 8 files changed, 35 insertions(+), 6 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 568254b..68b2d65 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -181,6 +181,7 @@ var defaultModelPrice = map[string]float64{ "mj_describe": 0.05, "mj_upscale": 0.05, "swap_face": 0.05, + "mj_upload": 0.05, } var ( diff --git a/constant/midjourney.go b/constant/midjourney.go index 1e9d30a..4fba479 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -27,6 +27,7 @@ const ( MjActionLowVariation = "LOW_VARIATION" MjActionPan = "PAN" MjActionSwapFace = "SWAP_FACE" + MjActionUpload = "UPLOAD" ) var MidjourneyModel2Action = map[string]string{ @@ -45,4 +46,5 @@ var MidjourneyModel2Action = map[string]string{ "mj_low_variation": MjActionLowVariation, "mj_pan": MjActionPan, "swap_face": MjActionSwapFace, + "mj_upload": MjActionUpload, } diff --git a/dto/midjourney.go b/dto/midjourney.go index c675f7e..40251ee 100644 --- a/dto/midjourney.go +++ b/dto/midjourney.go @@ -33,6 +33,12 @@ type MidjourneyResponse struct { Result string `json:"result"` } +type MidjourneyUploadResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Result []string `json:"result"` +} + type MidjourneyResponseWithStatusCode struct { StatusCode int `json:"statusCode"` Response MidjourneyResponse diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index a072c74..6006bc6 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -27,6 +27,7 @@ const ( RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + RelayModeMidjourneyUpload RelayModeAudioSpeech // tts RelayModeAudioTranscription // whisper @@ -81,6 +82,9 @@ func Path2RelayModeMidjourney(path string) int { } else if strings.HasSuffix(path, "/mj/insight-face/swap") { // midjourney plus relayMode = RelayModeSwapFace + } else if strings.HasSuffix(path, "/submit/upload-discord-images") { + // midjourney plus + relayMode = RelayModeMidjourneyUpload } else if strings.HasSuffix(path, "/mj/submit/imagine") { relayMode = RelayModeMidjourneyImagine } else if strings.HasSuffix(path, "/mj/submit/blend") { diff --git a/relay/relay-mj.go b/relay/relay-mj.go index b399061..73ea468 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -382,6 +382,8 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons midjRequest.Action = constant.MjActionShorten } else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复 midjRequest.Action = constant.MjActionBlend + } else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复 + midjRequest.Action = constant.MjActionUpload } else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果 mjId := "" if relayMode == relayconstant.RelayModeMidjourneyChange { @@ -580,7 +582,10 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons responseBody = []byte(newBody) } } - + if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" { + midjourneyTask.Progress = "100%" + midjourneyTask.Status = "SUCCESS" + } err = midjourneyTask.Insert() if err != nil { return &dto.MidjourneyResponse{ @@ -594,7 +599,6 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1) responseBody = []byte(newBody) } - //resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) bodyReader := io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/router/relay-router.go b/router/relay-router.go index 2bf2ca2..b8d4a43 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -79,5 +79,6 @@ func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) { relayMjRouter.GET("/task/:id/image-seed", controller.RelayMidjourney) relayMjRouter.POST("/task/list-by-condition", controller.RelayMidjourney) relayMjRouter.POST("/insight-face/swap", controller.RelayMidjourney) + relayMjRouter.POST("/submit/upload-discord-images", controller.RelayMidjourney) } } diff --git a/service/midjourney.go b/service/midjourney.go index 6bb3a9e..d7b2a9a 100644 --- a/service/midjourney.go +++ b/service/midjourney.go @@ -49,6 +49,8 @@ func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (strin action = constant.MjActionModal case relayconstant.RelayModeSwapFace: action = constant.MjActionSwapFace + case relayconstant.RelayModeMidjourneyUpload: + action = constant.MjActionUpload case relayconstant.RelayModeMidjourneySimpleChange: params := ConvertSimpleChangeParams(midjRequest.Content) if params == nil { @@ -220,7 +222,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err } var midjResponse dto.MidjourneyResponse - + var midjourneyUploadsResponse dto.MidjourneyUploadResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err @@ -230,13 +232,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_response_body_failed", statusCode), responseBody, err } respStr := string(responseBody) - log.Printf("responseBody: %s", respStr) + log.Printf("respStr: %s", respStr) if respStr == "" { return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil } else { err = json.Unmarshal(responseBody, &midjResponse) if err != nil { - return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse) + if err2 != nil { + return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err + } } } //log.Printf("midjResponse: %v", midjResponse) diff --git a/web/src/components/MjLogsTable.js b/web/src/components/MjLogsTable.js index e86ee31..d6aac98 100644 --- a/web/src/components/MjLogsTable.js +++ b/web/src/components/MjLogsTable.js @@ -90,6 +90,12 @@ function renderType(type) { 图混合 ); + case 'UPLOAD': + return ( + + 上传文件 + + ); case 'SHORTEN': return ( @@ -239,7 +245,7 @@ const renderTimestamp = (timestampInSeconds) => { // 修改renderDuration函数以包含颜色逻辑 function renderDuration(submit_time, finishTime) { // 确保startTime和finishTime都是有效的时间戳 - if (!submit_time || !finishTime) return 'N/A'; + if (!submit_time || !finishTime) return 'N/A'; // 将时间戳转换为Date对象 const start = new Date(submit_time); From c92ab3b569ae6e61014704983e8b0b1775051a72 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 16:13:08 +0800 Subject: [PATCH 02/21] =?UTF-8?q?feat:=20=E6=97=A5=E5=BF=97=E6=96=B0?= =?UTF-8?q?=E5=A2=9Erpm=E5=92=8Ctpm=E6=95=B0=E6=8D=AE=E3=80=82(close=20#38?= =?UTF-8?q?4)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/log.go | 23 +++++++++++++-- relay/relay-text.go | 9 +++++- web/src/components/LogsTable.js | 51 ++++++++++++++++++--------------- 3 files changed, 57 insertions(+), 26 deletions(-) diff --git a/model/log.go b/model/log.go index da0d76f..aa0650f 100644 --- a/model/log.go +++ b/model/log.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "one-api/common" "strings" + "time" ) type Log struct { @@ -172,12 +173,18 @@ type Stat struct { } func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (stat Stat) { - tx := DB.Table("logs").Select("sum(quota) quota, count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") + tx := DB.Table("logs").Select("sum(quota) quota") + + // 为rpm和tpm创建单独的查询 + rpmTpmQuery := DB.Table("logs").Select("count(*) rpm, sum(prompt_tokens) + sum(completion_tokens) tpm") + if username != "" { tx = tx.Where("username = ?", username) + rpmTpmQuery = rpmTpmQuery.Where("username = ?", username) } if tokenName != "" { tx = tx.Where("token_name = ?", tokenName) + rpmTpmQuery = rpmTpmQuery.Where("token_name = ?", tokenName) } if startTimestamp != 0 { tx = tx.Where("created_at >= ?", startTimestamp) @@ -187,11 +194,23 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa } if modelName != "" { tx = tx.Where("model_name = ?", modelName) + rpmTpmQuery = rpmTpmQuery.Where("model_name = ?", modelName) } if channel != 0 { tx = tx.Where("channel_id = ?", channel) + rpmTpmQuery = rpmTpmQuery.Where("channel_id = ?", channel) } - tx.Where("type = ?", LogTypeConsume).Scan(&stat) + + tx = tx.Where("type = ?", LogTypeConsume) + rpmTpmQuery = rpmTpmQuery.Where("type = ?", LogTypeConsume) + + // 只统计最近60秒的rpm和tpm + rpmTpmQuery = rpmTpmQuery.Where("created_at >= ?", time.Now().Add(-60*time.Second).Unix()) + + // 执行查询 + tx.Scan(&stat) + rpmTpmQuery.Scan(&stat) + return stat } diff --git a/relay/relay-text.go b/relay/relay-text.go index b704493..636be56 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -286,7 +286,14 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64, usePrice bool, extraContent string) { - + if usage == nil { + usage = &dto.Usage{ + PromptTokens: relayInfo.PromptTokens, + CompletionTokens: 0, + TotalTokens: relayInfo.PromptTokens, + } + extraContent += " ,(可能是请求出错)" + } useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 55106f2..03a7205 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -475,6 +475,9 @@ const LogsTable = () => { }; const handleEyeClick = async () => { + if (loadingStat) { + return; + } setLoadingStat(true); if (isAdminUser) { await getLogStat(); @@ -596,6 +599,7 @@ const LogsTable = () => { .catch((reason) => { showError(reason); }); + handleEyeClick(); }, []); const searchLogs = async () => { @@ -622,19 +626,17 @@ const LogsTable = () => {
-

- 使用明细(总消耗额度: - - {showStat ? renderQuota(stat.quota) : '点击查看'} - - ) -

+ + + 总消耗额度: {renderQuota(stat.quota)} + + + RPM: {stat.rpm} + + + TPM: {stat.tpm} + +
@@ -700,17 +702,19 @@ const LogsTable = () => { /> )} + - +
@@ -736,6 +740,7 @@ const LogsTable = () => { onChange={(value) => { setLogType(parseInt(value)); loadLogs(0, pageSize, parseInt(value)); + handleEyeClick(); }} > 全部 From 3b1745c712524fdb67459a46d15ee7bc25f97975 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 16:33:59 +0800 Subject: [PATCH 03/21] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E6=9F=A5=E8=AF=A2=E6=9D=A1=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/log.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/log.go b/model/log.go index aa0650f..d2e5b84 100644 --- a/model/log.go +++ b/model/log.go @@ -103,7 +103,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName tx = DB.Where("type = ?", logType) } if modelName != "" { - tx = tx.Where("model_name like ?", "%"+modelName+"%") + tx = tx.Where("model_name like ?", modelName) } if username != "" { tx = tx.Where("username = ?", username) @@ -132,7 +132,7 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int tx = DB.Where("user_id = ? and type = ?", userId, logType) } if modelName != "" { - tx = tx.Where("model_name = ?", modelName) + tx = tx.Where("model_name like ?", modelName) } if tokenName != "" { tx = tx.Where("token_name = ?", tokenName) From 54f6e660f1f565e68e38563cf43862bf810de5cb Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 17:36:26 +0800 Subject: [PATCH 04/21] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E6=9F=A5=E5=88=9D=E5=A7=8B=E6=97=B6=E9=97=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/LogsTable.js | 8 ++++---- web/src/helpers/utils.js | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 03a7205..b19adbb 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -1,11 +1,11 @@ import React, { useEffect, useState } from 'react'; import { API, - copy, + copy, getTodayStartTimestamp, isAdmin, showError, showSuccess, - timestamp2string, + timestamp2string } from '../helpers'; import { @@ -419,12 +419,12 @@ const LogsTable = () => { const [logType, setLogType] = useState(0); const isAdminUser = isAdmin(); let now = new Date(); - // 初始化start_timestamp为前一天 + // 初始化start_timestamp为今天0点 const [inputs, setInputs] = useState({ username: '', token_name: '', model_name: '', - start_timestamp: timestamp2string(now.getTime() / 1000 - 86400), + start_timestamp: timestamp2string(getTodayStartTimestamp()), end_timestamp: timestamp2string(now.getTime() / 1000 + 3600), channel: '', }); diff --git a/web/src/helpers/utils.js b/web/src/helpers/utils.js index 321b00a..579c22d 100644 --- a/web/src/helpers/utils.js +++ b/web/src/helpers/utils.js @@ -140,6 +140,12 @@ export function removeTrailingSlash(url) { } } +export function getTodayStartTimestamp() { + var now = new Date(); + now.setHours(0, 0, 0, 0); + return Math.floor(now.getTime() / 1000); +} + export function timestamp2string(timestamp) { let date = new Date(timestamp * 1000); let year = date.getFullYear().toString(); From 58b4c237a4e6dfdccd6240f0b11ad56f6e81e910 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 17:39:18 +0800 Subject: [PATCH 05/21] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96rpm=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/LogsTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index b19adbb..b725cd2 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -577,6 +577,7 @@ const LogsTable = () => { const refresh = async () => { // setLoading(true); setActivePage(1); + handleEyeClick(); await loadLogs(0, pageSize, logType); }; @@ -740,7 +741,6 @@ const LogsTable = () => { onChange={(value) => { setLogType(parseInt(value)); loadLogs(0, pageSize, parseInt(value)); - handleEyeClick(); }} > 全部 From b7690fe17d4b11c2818b4bade4590dc2a273153b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 18:06:25 +0800 Subject: [PATCH 06/21] =?UTF-8?q?fix:=20=E6=97=A5=E5=BF=97=E6=A8=A1?= =?UTF-8?q?=E7=B3=8A=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/LogsTable.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index b725cd2..c67dd16 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -534,6 +534,7 @@ const LogsTable = () => { } else { url = `/api/log/self/?p=${startIdx}&page_size=${pageSize}&type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; } + url = encodeURI(url); const res = await API.get(url); const { success, message, data } = res.data; if (success) { From f8f15bd1d09bd3ca2d1614bf09713cf14ec6946b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 1 Aug 2024 18:14:10 +0800 Subject: [PATCH 07/21] =?UTF-8?q?fix:=20rpm=E6=A8=A1=E7=B3=8A=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/LogsTable.js | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index c67dd16..2557c4b 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -449,8 +449,10 @@ const LogsTable = () => { const getLogSelfStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let url = `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`; + url = encodeURI(url); let res = await API.get( - `/api/log/self/stat?type=${logType}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}`, + url, ); const { success, message, data } = res.data; if (success) { @@ -463,8 +465,10 @@ const LogsTable = () => { const getLogStat = async () => { let localStartTimestamp = Date.parse(start_timestamp) / 1000; let localEndTimestamp = Date.parse(end_timestamp) / 1000; + let url = `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`; + url = encodeURI(url); let res = await API.get( - `/api/log/stat?type=${logType}&username=${username}&token_name=${token_name}&model_name=${model_name}&start_timestamp=${localStartTimestamp}&end_timestamp=${localEndTimestamp}&channel=${channel}`, + url, ); const { success, message, data } = res.data; if (success) { From 22a98c58796ccfc7b4fac449d25cb032e8494a97 Mon Sep 17 00:00:00 2001 From: HowieWu <98788152+utopeadia@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:20:26 +0800 Subject: [PATCH 08/21] =?UTF-8?q?=E4=BF=AE=E6=94=B9Gemini=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 使用GEMINI_MODEL_API环境变量覆盖默认版本映射,使用","分隔不同模型和版本 -e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta,gemini-1.5-pro:v1beta,gemini-1.5-flash-latest:v1beta,gemini-1.5-flash-001:v1beta,gemini-1.5-flash:v1beta,gemini-ultra:v1beta,gemini-1.5-pro-exp-0801:v1beta" --- relay/channel/gemini/adaptor.go | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index e132d2f..35f236a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -6,12 +6,15 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" + "os" "one-api/dto" "one-api/relay/channel" + "strings" relaycommon "one-api/relay/common" ) type Adaptor struct { + modelVersionMap map[string]string } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -25,18 +28,32 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { -} - -// 定义一个映射,存储模型名称和对应的版本 -var modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-ultra": "v1beta", + modelVersionMapStr := os.Getenv("GEMINI_MODEL_API") + if modelVersionMapStr == "" { + a.modelVersionMap = map[string]string{ + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", + } + return + } + a.modelVersionMap = make(map[string]string) + for _, pair := range strings.Split(modelVersionMapStr, ",") { + parts := strings.Split(pair, ":") + if len(parts) == 2 { + a.modelVersionMap[parts[0]] = parts[1] + } + } } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := modelVersionMap[info.UpstreamModelName] + version, beta := a.modelVersionMap[info.UpstreamModelName] if !beta { if info.ApiVersion != "" { version = info.ApiVersion From fc0db4505c7a6274e7b5050c53aea33f4c9f3b36 Mon Sep 17 00:00:00 2001 From: HowieWu <98788152+utopeadia@users.noreply.github.com> Date: Fri, 2 Aug 2024 11:25:41 +0800 Subject: [PATCH 09/21] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加Gemini版本变量说明 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 30062a7..57ab54e 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,7 @@ - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 +- `GEMINI_MODEL_API`:Gemini模型指定版本(v1/v1beta),如果配置会覆盖默认配置需要完整给出全部v1beta模型,使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置:gemini-1.5-pro-latest,gemini-1.5-pro-001,gemini-1.5-pro,gemini-1.5-pro-exp-0801,gemini-1.5-flash-latest,gemini-1.5-flash-001,gemini-1.5-flash,gemini-ultra模型为v1beta,其他为v1。 ## 部署 ### 部署要求 From e504665f68d298ff7e88fbffa570eefeb6997772 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 2 Aug 2024 17:23:59 +0800 Subject: [PATCH 10/21] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96Gemini=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=89=88=E6=9C=AC=E8=8E=B7=E5=8F=96=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- constant/env.go | 29 +++++++++++++++++++++++++++++ main.go | 2 ++ relay/channel/gemini/adaptor.go | 28 +++------------------------- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 57ab54e..971871a 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 -- `GEMINI_MODEL_API`:Gemini模型指定版本(v1/v1beta),如果配置会覆盖默认配置需要完整给出全部v1beta模型,使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置:gemini-1.5-pro-latest,gemini-1.5-pro-001,gemini-1.5-pro,gemini-1.5-pro-exp-0801,gemini-1.5-flash-latest,gemini-1.5-flash-001,gemini-1.5-flash,gemini-ultra模型为v1beta,其他为v1。 +- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 ## 部署 ### 部署要求 diff --git a/constant/env.go b/constant/env.go index 76146ca..dd3ae65 100644 --- a/constant/env.go +++ b/constant/env.go @@ -1,7 +1,10 @@ package constant import ( + "fmt" "one-api/common" + "os" + "strings" ) var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30) @@ -15,3 +18,29 @@ var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) + +var GeminiModelMap = map[string]string{ + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", +} + +func InitEnv() { + modelVersionMapStr := strings.TrimSpace(os.Getenv("GEMINI_MODEL_MAP")) + if modelVersionMapStr == "" { + return + } + for _, pair := range strings.Split(modelVersionMapStr, ",") { + parts := strings.Split(pair, ":") + if len(parts) == 2 { + GeminiModelMap[parts[0]] = parts[1] + } else { + common.SysError(fmt.Sprintf("invalid model version map: %s", pair)) + } + } +} diff --git a/main.go b/main.go index 959b795..efee772 100644 --- a/main.go +++ b/main.go @@ -55,6 +55,8 @@ func main() { common.FatalLog("failed to initialize Redis: " + err.Error()) } + // Initialize constants + constant.InitEnv() // Initialize options model.InitOptionMap() if common.RedisEnabled { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 35f236a..4c4649f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -6,15 +6,13 @@ import ( "github.com/gin-gonic/gin" "io" "net/http" - "os" + "one-api/constant" "one-api/dto" "one-api/relay/channel" - "strings" relaycommon "one-api/relay/common" ) type Adaptor struct { - modelVersionMap map[string]string } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -28,32 +26,12 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { - modelVersionMapStr := os.Getenv("GEMINI_MODEL_API") - if modelVersionMapStr == "" { - a.modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-pro-001": "v1beta", - "gemini-1.5-pro": "v1beta", - "gemini-1.5-pro-exp-0801": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-1.5-flash-001": "v1beta", - "gemini-1.5-flash": "v1beta", - "gemini-ultra": "v1beta", - } - return - } - a.modelVersionMap = make(map[string]string) - for _, pair := range strings.Split(modelVersionMapStr, ",") { - parts := strings.Split(pair, ":") - if len(parts) == 2 { - a.modelVersionMap[parts[0]] = parts[1] - } - } + } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := a.modelVersionMap[info.UpstreamModelName] + version, beta := constant.GeminiModelMap[info.UpstreamModelName] if !beta { if info.ApiVersion != "" { version = info.ApiVersion From 88ba8a840ee8d806294b2f20e61a1ed7a514361a Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 01:28:18 +0800 Subject: [PATCH 11/21] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E5=85=85?= =?UTF-8?q?=E5=80=BC=E8=AE=A2=E5=8D=95=E5=8F=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/topup.go | 4 ++-- controller/user.go | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index 87c68c3..dc1e265 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -101,8 +101,8 @@ func RequestEpay(c *gin.Context) { } uri, params, err := client.Purchase(&epay.PurchaseArgs{ Type: payType, - ServiceTradeNo: "A" + tradeNo, - Name: "B" + tradeNo, + ServiceTradeNo: fmt.Sprintf("USR%d-%s", id, tradeNo), + Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), Device: epay.PC, NotifyUrl: notifyUrl, diff --git a/controller/user.go b/controller/user.go index a6798eb..6faec2b 100644 --- a/controller/user.go +++ b/controller/user.go @@ -791,11 +791,11 @@ type topUpRequest struct { Key string `json:"key"` } -var lock = sync.Mutex{} +var topUpLock = sync.Mutex{} func TopUp(c *gin.Context) { - lock.Lock() - defer lock.Unlock() + topUpLock.Lock() + defer topUpLock.Unlock() req := topUpRequest{} err := c.ShouldBindJSON(&req) if err != nil { From 8a9ff36fbf2d78026f2afddff1f12f46e6e09796 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 16:55:29 +0800 Subject: [PATCH 12/21] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96relay=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 75 +++++++++++++++++++++++++-------------- middleware/distributor.go | 1 - 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 0c79015..66339f4 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -39,38 +39,28 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode func Relay(c *gin.Context) { relayMode := constant.Path2RelayMode(c.Request.URL.Path) - retryTimes := common.RetryTimes requestId := c.GetString(common.RequestIdKey) - channelId := c.GetInt("channel_id") - channelType := c.GetInt("channel_type") - channelName := c.GetString("channel_name") group := c.GetString("group") originalModel := c.GetString("original_model") - openaiErr := relayHandler(c, relayMode) - c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) - if openaiErr != nil { - go processChannelError(c, channelId, channelType, channelName, openaiErr) - } else { - retryTimes = 0 - } - for i := 0; shouldRetry(c, channelId, openaiErr, retryTimes) && i < retryTimes; i++ { - channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) + var openaiErr *dto.OpenAIErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + common.LogError(c, fmt.Sprintf("Failed to get channel: %s", err.Error())) break } - channelId = channel.Id - useChannel := c.GetStringSlice("use_channel") - useChannel = append(useChannel, fmt.Sprintf("%d", channel.Id)) - c.Set("use_channel", useChannel) - common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) - middleware.SetupContextForSelectedChannel(c, channel, originalModel) - requestBody, err := common.GetRequestBody(c) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - openaiErr = relayHandler(c, relayMode) - if openaiErr != nil { - go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + openaiErr = relayRequest(c, relayMode, channel) + + if openaiErr == nil { + return // 成功处理请求,直接返回 + } + + go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break } } useChannel := c.GetStringSlice("use_channel") @@ -90,7 +80,36 @@ func Relay(c *gin.Context) { } } -func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { +func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relayHandler(c, relayMode) +} + +func addUsedChannel(c *gin.Context, channelId int) { + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) +} + +func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { + if retryCount == 0 { + return &model.Channel{ + Id: c.GetInt("channel_id"), + Type: c.GetInt("channel_type"), + Name: c.GetString("channel_name"), + }, nil + } + channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) + if err != nil { + return nil, err + } + middleware.SetupContextForSelectedChannel(c, channel, originalModel) + return channel, nil +} + +func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retryTimes int) bool { if openaiErr == nil { return false } @@ -114,6 +133,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt return true } if openaiErr.StatusCode == http.StatusBadRequest { + channelType := c.GetInt("channel_type") + if channelType == common.ChannelTypeAnthropic { + return true + } return false } if openaiErr.StatusCode == 408 { diff --git a/middleware/distributor.go b/middleware/distributor.go index f150b41..fad868d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -184,7 +184,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel == nil { return } - c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) From fbe6cd75b1531311dc1438fb78a64ffe8b56e083 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 17:07:14 +0800 Subject: [PATCH 13/21] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96relay=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 66339f4..4726d7e 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -47,8 +47,14 @@ func Relay(c *gin.Context) { for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - common.LogError(c, fmt.Sprintf("Failed to get channel: %s", err.Error())) - break + errMsg := fmt.Sprintf("获取渠道出错: %s", err.Error()) + common.LogError(c, errMsg) + openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + openaiErr.Error.Message = common.MessageWithRequestId(errMsg, requestId) + c.JSON(openaiErr.StatusCode, gin.H{ + "error": openaiErr.Error, + }) + return } openaiErr = relayRequest(c, relayMode, channel) From dd12a0052f1f9c8110230aa4b8567ea345a10fc4 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 17:12:16 +0800 Subject: [PATCH 14/21] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96relay=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 4726d7e..30217f0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -47,14 +48,9 @@ func Relay(c *gin.Context) { for i := 0; i <= common.RetryTimes; i++ { channel, err := getChannel(c, group, originalModel, i) if err != nil { - errMsg := fmt.Sprintf("获取渠道出错: %s", err.Error()) - common.LogError(c, errMsg) + common.LogError(c, err.Error()) openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) - openaiErr.Error.Message = common.MessageWithRequestId(errMsg, requestId) - c.JSON(openaiErr.StatusCode, gin.H{ - "error": openaiErr.Error, - }) - return + break } openaiErr = relayRequest(c, relayMode, channel) @@ -72,7 +68,7 @@ func Relay(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c.Request.Context(), retryLogStr) + common.LogInfo(c, retryLogStr) } if openaiErr != nil { @@ -109,7 +105,7 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m } channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) if err != nil { - return nil, err + return nil, errors.New(fmt.Sprintf("获取重试渠道失败: %s", err.Error())) } middleware.SetupContextForSelectedChannel(c, channel, originalModel) return channel, nil From afd328efcfcdfb44efa1f70c42f30a93c8fae1b9 Mon Sep 17 00:00:00 2001 From: HowieWu <98788152+utopeadia@users.noreply.github.com> Date: Sat, 3 Aug 2024 17:19:44 +0800 Subject: [PATCH 15/21] =?UTF-8?q?=E4=BF=AE=E6=94=B9readme=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 971871a..b561cda 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,7 @@ - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用。 - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true`。 - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度。 -- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用模型:版本指定,","分隔,例如:-e GEMINI_MODEL_API="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 +- `GEMINI_MODEL_MAP`:Gemini模型指定版本(v1/v1beta),使用“模型:版本”指定,","分隔,例如:-e GEMINI_MODEL_MAP="gemini-1.5-pro-latest:v1beta,gemini-1.5-pro-001:v1beta",为空则使用默认配置 ## 部署 ### 部署要求 From 5acf0745412caf80b05d97477e080df62cbf0a89 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 17:32:28 +0800 Subject: [PATCH 16/21] =?UTF-8?q?chore:=20=E4=BC=98=E5=8C=96=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E7=A6=81=E7=94=A8=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 2 +- controller/relay.go | 19 +++++++++++++------ middleware/distributor.go | 7 +------ model/channel.go | 7 +++++++ relay/relay-mj.go | 2 +- 5 files changed, 23 insertions(+), 14 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index fe27978..95c4a60 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -240,7 +240,7 @@ func testAllChannels(notify bool) error { } // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { + if !channel.GetAutoBan() { ban = false } diff --git a/controller/relay.go b/controller/relay.go index 30217f0..3a6fb6f 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -59,7 +59,7 @@ func Relay(c *gin.Context) { return // 成功处理请求,直接返回 } - go processChannelError(c, channel.Id, channel.Type, channel.Name, openaiErr) + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) if !shouldRetry(c, openaiErr, common.RetryTimes-i) { break @@ -97,10 +97,16 @@ func addUsedChannel(c *gin.Context, channelId int) { func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*model.Channel, error) { if retryCount == 0 { + autoBan := c.GetBool("auto_ban") + autoBanInt := 1 + if !autoBan { + autoBanInt = 0 + } return &model.Channel{ - Id: c.GetInt("channel_id"), - Type: c.GetInt("channel_type"), - Name: c.GetString("channel_name"), + Id: c.GetInt("channel_id"), + Type: c.GetInt("channel_type"), + Name: c.GetString("channel_name"), + AutoBan: &autoBanInt, }, nil } channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, retryCount) @@ -154,8 +160,9 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry return true } -func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, err *dto.OpenAIErrorWithStatusCode) { - autoBan := c.GetBool("auto_ban") +func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { + // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 + // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) if service.ShouldDisableChannel(channelType, err) && autoBan { service.DisableChannel(channelId, channelName, err.Error.Message) diff --git a/middleware/distributor.go b/middleware/distributor.go index fad868d..1be3b31 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -187,15 +187,10 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) c.Set("channel_type", channel.Type) - ban := true - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } - c.Set("auto_ban", ban) + c.Set("auto_ban", channel.GetAutoBan()) c.Set("model_mapping", channel.GetModelMapping()) c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/model/channel.go b/model/channel.go index 7db3f07..3f9d9ed 100644 --- a/model/channel.go +++ b/model/channel.go @@ -61,6 +61,13 @@ func (channel *Channel) SetOtherInfo(otherInfo map[string]interface{}) { channel.OtherInfo = string(otherInfoBytes) } +func (channel *Channel) GetAutoBan() bool { + if channel.AutoBan == nil { + return false + } + return *channel.AutoBan == 1 +} + func (channel *Channel) Save() error { return DB.Save(channel).Error } diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 73ea468..4dd81c5 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -549,7 +549,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons if err != nil { common.SysError("get_channel_null: " + err.Error()) } - if channel.AutoBan != nil && *channel.AutoBan == 1 && common.AutomaticDisableChannelEnabled { + if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled { model.UpdateChannelStatusById(midjourneyTask.ChannelId, 2, "No available account instance") } } From 0123ad4d61986f8d4c73b034d030dc01d43bc9e5 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 17:46:13 +0800 Subject: [PATCH 17/21] =?UTF-8?q?fix:=20=E9=87=8D=E8=AF=95=E5=90=8Erequest?= =?UTF-8?q?=20id=E4=B8=8D=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 8 ++++---- relay/relay-text.go | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 3a6fb6f..13fbde0 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -163,7 +163,7 @@ func shouldRetry(c *gin.Context, openaiErr *dto.OpenAIErrorWithStatusCode, retry func processChannelError(c *gin.Context, channelId int, channelType int, channelName string, autoBan bool, err *dto.OpenAIErrorWithStatusCode) { // 不要使用context获取渠道信息,异步处理时可能会出现渠道信息不一致的情况 // do not use context to get channel info, there may be inconsistent channel info when processing asynchronously - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) + common.LogError(c, fmt.Sprintf("relay error (channel #%d, status code: %d): %s", channelId, err.StatusCode, err.Error.Message)) if service.ShouldDisableChannel(channelType, err) && autoBan { service.DisableChannel(channelId, channelName, err.Error.Message) } @@ -240,14 +240,14 @@ func RelayTask(c *gin.Context) { for i := 0; shouldRetryTaskRelay(c, channelId, taskErr, retryTimes) && i < retryTimes; i++ { channel, err := model.CacheGetRandomSatisfiedChannel(group, originalModel, i) if err != nil { - common.LogError(c.Request.Context(), fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) + common.LogError(c, fmt.Sprintf("CacheGetRandomSatisfiedChannel failed: %s", err.Error())) break } channelId = channel.Id useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) c.Set("use_channel", useChannel) - common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) + common.LogInfo(c, fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) @@ -257,7 +257,7 @@ func RelayTask(c *gin.Context) { useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) - common.LogInfo(c.Request.Context(), retryLogStr) + common.LogInfo(c, retryLogStr) } if taskErr != nil { if taskErr.StatusCode == http.StatusTooManyRequests { diff --git a/relay/relay-text.go b/relay/relay-text.go index 636be56..e7c5388 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -253,13 +253,13 @@ func preConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommo if tokenQuota > 100*preConsumedQuota { // 令牌额度充足,信任令牌 preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) + common.LogInfo(c, fmt.Sprintf("user %d quota %d and token %d quota %d are enough, trusted and no need to pre-consume", relayInfo.UserId, userQuota, relayInfo.TokenId, tokenQuota)) } } else { // in this case, we do not pre-consume quota // because the user has enough quota preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) + common.LogInfo(c, fmt.Sprintf("user %d with unlimited token has enough quota %d, trusted and no need to pre-consume", relayInfo.UserId, userQuota)) } } if preConsumedQuota > 0 { From 0b4ef42d861380d45ddae569bcb4c8248e92b30b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 3 Aug 2024 22:41:47 +0800 Subject: [PATCH 18/21] fix: channel typ error --- relay/common/relay_info.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 564a7ad..3ed5ee3 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -33,7 +33,7 @@ type RelayInfo struct { } func GenRelayInfo(c *gin.Context) *RelayInfo { - channelType := c.GetInt("channel") + channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") @@ -112,7 +112,7 @@ type TaskRelayInfo struct { } func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo { - channelType := c.GetInt("channel") + channelType := c.GetInt("channel_type") channelId := c.GetInt("channel_id") tokenId := c.GetInt("token_id") From 5d0d268c975a68108c3ba1256c29c3dd289ac25d Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 4 Aug 2024 00:17:48 +0800 Subject: [PATCH 19/21] fix: epay --- controller/topup.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index dc1e265..995f412 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -94,6 +94,7 @@ func RequestEpay(c *gin.Context) { returnUrl, _ := url.Parse(constant.ServerAddress + "/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) + tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) client := GetEpayClient() if client == nil { c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) @@ -101,7 +102,7 @@ func RequestEpay(c *gin.Context) { } uri, params, err := client.Purchase(&epay.PurchaseArgs{ Type: payType, - ServiceTradeNo: fmt.Sprintf("USR%d-%s", id, tradeNo), + ServiceTradeNo: tradeNo, Name: fmt.Sprintf("TUC%d", req.Amount), Money: strconv.FormatFloat(payMoney, 'f', 2, 64), Device: epay.PC, @@ -120,7 +121,7 @@ func RequestEpay(c *gin.Context) { UserId: id, Amount: amount, Money: payMoney, - TradeNo: "A" + tradeNo, + TradeNo: tradeNo, CreateTime: time.Now().Unix(), Status: "pending", } From a0a3807bd4e6ea7a43161efe622e059f74d5d09c Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 4 Aug 2024 03:12:24 +0800 Subject: [PATCH 20/21] chore: epay --- controller/topup.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/controller/topup.go b/controller/topup.go index 995f412..c4b1aa9 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -41,12 +41,12 @@ func GetEpayClient() *epay.Client { return withUrl } -func getPayMoney(amount float64, user model.User) float64 { +func getPayMoney(amount float64, group string) float64 { if !common.DisplayInCurrencyEnabled { amount = amount / common.QuotaPerUnit } // 别问为什么用float64,问就是这么点钱没必要 - topupGroupRatio := common.GetTopupGroupRatio(user.Group) + topupGroupRatio := common.GetTopupGroupRatio(group) if topupGroupRatio == 0 { topupGroupRatio = 1 } @@ -75,8 +75,12 @@ func RequestEpay(c *gin.Context) { } id := c.GetInt("id") - user, _ := model.GetUserById(id, false) - payMoney := getPayMoney(float64(req.Amount), *user) + group, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) + return + } + payMoney := getPayMoney(float64(req.Amount), group) if payMoney < 0.01 { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return @@ -233,8 +237,12 @@ func RequestAmount(c *gin.Context) { return } id := c.GetInt("id") - user, _ := model.GetUserById(id, false) - payMoney := getPayMoney(float64(req.Amount), *user) + group, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(200, gin.H{"message": "error", "data": "获取用户分组失败"}) + return + } + payMoney := getPayMoney(float64(req.Amount), group) if payMoney <= 0.01 { c.JSON(200, gin.H{"message": "error", "data": "充值金额过低"}) return From 67878731fc198e47bbf62fe6cbd06262d94c10d3 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sun, 4 Aug 2024 14:35:16 +0800 Subject: [PATCH 21/21] feat: log user id --- middleware/auth.go | 6 ++++++ model/token.go | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/middleware/auth.go b/middleware/auth.go index edd15de..d2c9b3c 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -143,6 +143,12 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] } token, err := model.ValidateUserToken(key) + if token != nil { + id := c.GetInt("id") + if id == 0 { + c.Set("id", token.Id) + } + } if err != nil { abortWithOpenAiMessage(c, http.StatusUnauthorized, err.Error()) return diff --git a/model/token.go b/model/token.go index 27907af..272c573 100644 --- a/model/token.go +++ b/model/token.go @@ -51,12 +51,12 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status == common.TokenStatusExhausted { keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") + return token, errors.New("该令牌额度已用尽 TokenStatusExhausted[sk-" + keyPrefix + "***" + keySuffix + "]") } else if token.Status == common.TokenStatusExpired { - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if token.Status != common.TokenStatusEnabled { - return nil, errors.New("该令牌状态不可用") + return token, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { if !common.RedisEnabled { @@ -66,7 +66,7 @@ func ValidateUserToken(key string) (token *Token, err error) { common.SysError("failed to update token status" + err.Error()) } } - return nil, errors.New("该令牌已过期") + return token, errors.New("该令牌已过期") } if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { @@ -79,7 +79,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } keyPrefix := key[:3] keySuffix := key[len(key)-3:] - return nil, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) + return token, errors.New(fmt.Sprintf("[sk-%s***%s] 该令牌额度已用尽 !token.UnlimitedQuota && token.RemainQuota = %d", keyPrefix, keySuffix, token.RemainQuota)) } return token, nil }