From 2199cf2304bbf0cfcbb582680fc87399f7147b3e Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sun, 12 Nov 2023 23:31:59 +0800 Subject: [PATCH 01/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=8C=89=E9=92=AEbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/SiderBar.js | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/web/src/components/SiderBar.js b/web/src/components/SiderBar.js index fc22b81..95298c9 100644 --- a/web/src/components/SiderBar.js +++ b/web/src/components/SiderBar.js @@ -15,7 +15,7 @@ import { IconLayers, IconSetting, IconCreditCard, - IconSemiLogo, + IconComment, IconHome, IconImage } from '@douyinfe/semi-icons'; @@ -36,7 +36,13 @@ let headerButtons = [ icon: , className: isAdmin()?'semi-navigation-item-normal':'tableHiddle', }, - + { + text: '聊天', + itemKey: 'chat', + to: '/chat', + icon: , + className: localStorage.getItem('chat_link')?'semi-navigation-item-normal':'tableHiddle', + }, { text: '令牌', itemKey: 'token', @@ -89,14 +95,6 @@ let headerButtons = [ // } ]; -if (localStorage.getItem('chat_link')) { - headerButtons.splice(1, 0, { - name: '聊天', - to: '/chat', - icon: 'comments' - }); -} - const HeaderBar = () => { const [userState, userDispatch] = useContext(UserContext); let navigate = useNavigate(); @@ -134,6 +132,7 @@ const HeaderBar = () => { midjourney: "/midjourney", setting: "/setting", about: "/about", + chat: "/chat", }; return ( Date: Sun, 12 Nov 2023 23:32:22 +0800 Subject: [PATCH 02/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0mj=E6=B8=A0=E9=81=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 1 + controller/model.go | 9 +++++++++ docker-compose.yml | 2 +- web/src/constants/channel.constants.js | 1 + 4 files changed, 12 insertions(+), 1 deletion(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index bb2adc7..f1cc07d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -14,6 +14,7 @@ import ( // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ + "midjourney": 50, "gpt-4": 15, "gpt-4-0314": 15, "gpt-4-0613": 15, diff --git a/controller/model.go b/controller/model.go index f990433..201d643 100644 --- a/controller/model.go +++ b/controller/model.go @@ -54,6 +54,15 @@ func init() { }) // https://platform.openai.com/docs/models/model-endpoint-compatibility openAIModels = []OpenAIModels{ + { + Id: "midjourney", + Object: "model", + Created: 1677649963, + OwnedBy: "Midjourney", + Permission: permission, + Root: "midjourney", + Parent: nil, + }, { Id: "dall-e-2", Object: "model", diff --git a/docker-compose.yml b/docker-compose.yml index 9b814a0..6c5350d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.4' services: one-api: - image: justsong/one-api:latest + image: calciumion/neko-api:main container_name: one-api restart: always command: --log-dir /app/logs diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index 7640774..6da8daf 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -1,5 +1,6 @@ export const CHANNEL_OPTIONS = [ { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 99, text: 'Midjourney-Proxy', value: 99, color: 'green' }, { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, From b4bd9a19d9313364f6ca0bfc740ce63e9c124cf5 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Sun, 12 Nov 2023 23:33:27 +0800 Subject: [PATCH 03/15] add docker-image-amd64.yml --- .github/workflows/docker-image-amd64.yml | 54 +++++++----------------- 1 file changed, 16 insertions(+), 38 deletions(-) diff --git a/.github/workflows/docker-image-amd64.yml b/.github/workflows/docker-image-amd64.yml index e3b8439..1ab220c 100644 --- a/.github/workflows/docker-image-amd64.yml +++ b/.github/workflows/docker-image-amd64.yml @@ -1,54 +1,32 @@ -name: Publish Docker image (amd64) +name: Docker Image CI on: push: - tags: - - '*' - workflow_dispatch: - inputs: - name: - description: 'reason' - required: false + branches: [ "main" ] + pull_request: + branches: [ "main" ] + jobs: - push_to_registries: - name: Push Docker image to multiple registries + + build: + runs-on: ubuntu-latest - permissions: - packages: write - contents: read + steps: - - name: Check out the repo - uses: actions/checkout@v3 - - - name: Save version info - run: | - git describe --tags > VERSION - - - name: Log in to Docker Hub - uses: docker/login-action@v2 + - uses: actions/checkout@v3 + - uses: docker/login-action@v3.0.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Log in to the Container registry - uses: docker/login-action@v2 - with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Extract metadata (tags, labels) for Docker id: meta - uses: docker/metadata-action@v4 + uses: docker/metadata-action@v3 with: - images: | - justsong/one-api - ghcr.io/${{ github.repository }} - - - name: Build and push Docker images - uses: docker/build-push-action@v3 + images: calciumion/neko-api + - name: Build the Docker image + uses: docker/build-push-action@v5.0.0 with: context: . push: true tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} \ No newline at end of file + labels: ${{ steps.meta.outputs.labels }} From a7d7fb789180fcae06a3decd81ff9a24ea05b7ab Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:17:29 +0800 Subject: [PATCH 04/15] Update README.md --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 89dfba3..2858709 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,10 @@ 5. 渠道显示已使用额度,支持指定组织访问 6. 分页支持选择每页显示数量 +## 交流群 +![IMG_2847](https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266) + + ## 界面截图 ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605) ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/f4f40ed4-8ccb-43d7-a580-90677827646d) From 15b5db66dac362b9ae88d7fe3bf51d2f77e5a585 Mon Sep 17 00:00:00 2001 From: Calcium-Ion <61247483+Calcium-Ion@users.noreply.github.com> Date: Wed, 15 Nov 2023 13:23:09 +0800 Subject: [PATCH 05/15] Update README.md --- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 2858709..dd8d639 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,7 @@ 6. 分页支持选择每页显示数量 ## 交流群 -![IMG_2847](https://github.com/Calcium-Ion/new-api/assets/61247483/de536a8a-0161-47a7-a0a2-66ef6de81266) - + ## 界面截图 ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605) From 16ad764f9b7162b9032435b2e6163f7157281b52 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 15 Nov 2023 18:27:13 +0800 Subject: [PATCH 06/15] try to fix email --- controller/relay-audio.go | 2 +- controller/relay-image.go | 2 +- controller/relay-mj.go | 2 +- controller/relay-text.go | 4 ++-- model/token.go | 26 ++++++++++++++++---------- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/controller/relay-audio.go b/controller/relay-audio.go index fe91dbc..13d9c9f 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -99,7 +99,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode go func() { quota := countTokenText(audioResponse.Text, audioModel) quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota) + err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/controller/relay-image.go b/controller/relay-image.go index 5cebcdb..8c16ec1 100644 --- a/controller/relay-image.go +++ b/controller/relay-image.go @@ -147,7 +147,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode var textResponse ImageResponse defer func(ctx context.Context) { if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0) + err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/controller/relay-mj.go b/controller/relay-mj.go index 948c57c..89b0f0c 100644 --- a/controller/relay-mj.go +++ b/controller/relay-mj.go @@ -359,7 +359,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse { defer func(ctx context.Context) { if consumeQuota { - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0) + err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) if err != nil { common.SysError("error consuming token remain quota: " + err.Error()) } diff --git a/controller/relay-text.go b/controller/relay-text.go index 6f56be8..2729650 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -400,7 +400,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { if preConsumedQuota != 0 { go func(ctx context.Context) { // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0) + err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false) if err != nil { common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) } @@ -434,7 +434,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { quota = 0 } quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota) + err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { common.LogError(ctx, "error consuming token remain quota: "+err.Error()) } diff --git a/model/token.go b/model/token.go index 06b9775..5c4bc55 100644 --- a/model/token.go +++ b/model/token.go @@ -5,6 +5,7 @@ import ( "fmt" "gorm.io/gorm" "one-api/common" + "strconv" "strings" ) @@ -194,22 +195,31 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) { return 0, err } if userQuota < quota { - return userQuota, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) + return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) } if !token.UnlimitedQuota { err = DecreaseTokenQuota(tokenId, quota) if err != nil { - return userQuota, err + return 0, err } } err = DecreaseUserQuota(token.UserId, quota) - return userQuota, err + return userQuota - quota, err } -func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int) (err error) { +func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) { token, err := GetTokenById(tokenId) if quota > 0 { + err = DecreaseUserQuota(token.UserId, quota) + } else { + err = IncreaseUserQuota(token.UserId, -quota) + } + if err != nil { + return err + } + + if sendEmail { quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 if quotaTooLow || noMoreQuota { @@ -229,16 +239,12 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo if err != nil { common.SysError("failed to send email" + err.Error()) } + common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota)) } }() } - err = DecreaseUserQuota(token.UserId, quota) - } else { - err = IncreaseUserQuota(token.UserId, -quota) - } - if err != nil { - return err } + if !token.UnlimitedQuota { if quota > 0 { err = DecreaseTokenQuota(tokenId, quota) From 63cd3f05f2a6e4e583a132206eeefb0dc15ec21e Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 15 Nov 2023 21:05:14 +0800 Subject: [PATCH 07/15] support tts --- common/model-ratio.go | 4 ++- common/utils.go | 9 +++++++ controller/relay-audio.go | 56 ++++++++++++++++++++++++++++++++------- controller/relay-utils.go | 9 +++++++ controller/relay.go | 6 +++++ middleware/distributor.go | 9 ++++--- router/relay-router.go | 1 + 7 files changed, 81 insertions(+), 13 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index f1cc07d..820f228 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{ "text-davinci-003": 10, "text-davinci-edit-001": 10, "code-davinci-edit-001": 10, - "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens + "tts-1": 7.5, // 1k characters -> $0.015 + "tts-1-hd": 15, // 1k characters -> $0.03 "davinci": 10, "curie": 10, "babbage": 10, diff --git a/common/utils.go b/common/utils.go index 21bec8f..d65d42a 100644 --- a/common/utils.go +++ b/common/utils.go @@ -207,3 +207,12 @@ func String2Int(str string) int { } return num } + +func StringsContains(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} diff --git a/controller/relay-audio.go b/controller/relay-audio.go index 13d9c9f..e959e73 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -11,10 +11,19 @@ import ( "net/http" "one-api/common" "one-api/model" + "strings" ) +var availableVoices = []string{ + "alloy", + "echo", + "fable", + "onyx", + "nova", + "shimmer", +} + func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - audioModel := "whisper-1" tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode userId := c.GetInt("id") group := c.GetString("group") + var audioRequest AudioRequest + err := common.UnmarshalBodyReusable(c, &audioRequest) + if err != nil { + return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + } + + // request validation + if audioRequest.Model == "" { + return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) + } + + if strings.HasPrefix(audioRequest.Model, "tts-1") { + if audioRequest.Voice == "" { + return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) + } + if !common.StringsContains(availableVoices, audioRequest.Voice) { + return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) + } + } + preConsumedTokens := common.PreConsumedQuota - modelRatio := common.GetModelRatio(audioModel) + modelRatio := common.GetModelRatio(audioRequest.Model) groupRatio := common.GetGroupRatio(group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) @@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } - if modelMap[audioModel] != "" { - audioModel = modelMap[audioModel] + if modelMap[audioRequest.Model] != "" { + audioRequest.Model = modelMap[audioRequest.Model] } } @@ -97,7 +126,12 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode defer func(ctx context.Context) { go func() { - quota := countTokenText(audioResponse.Text, audioModel) + var quota int + if strings.HasPrefix(audioRequest.Model, "tts-1") { + quota = countAudioToken(audioRequest.Input, audioRequest.Model) + } else { + quota = countAudioToken(audioResponse.Text, audioRequest.Model) + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { @@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId) + model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) @@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if err != nil { return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) } - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + if strings.HasPrefix(audioRequest.Model, "tts-1") { + + } else { + err = json.Unmarshal(responseBody, &audioResponse) + if err != nil { + return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + } } resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index d2f3d2f..40aa547 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -10,6 +10,7 @@ import ( "one-api/common" "strconv" "strings" + "unicode/utf8" ) var stopFinishReason = "stop" @@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int { return 0 } +func countAudioToken(text string, model string) int { + if strings.HasPrefix(model, "tts") { + return utf8.RuneCountInString(text) + } else { + return countTokenText(text, model) + } +} + func countTokenText(text string, model string) int { tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text) diff --git a/controller/relay.go b/controller/relay.go index c505c22..2ca2bc2 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string { return input } +type AudioRequest struct { + Model string `json:"model"` + Voice string `json:"voice"` + Input string `json:"input"` +} + type ChatRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` diff --git a/middleware/distributor.go b/middleware/distributor.go index c49a40d..c9d8be8 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) { if modelRequest.Model == "" { modelRequest.Model = "midjourney" } - } else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - err = common.UnmarshalBodyReusable(c, &modelRequest) } + err = common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { abortWithMessage(c, http.StatusBadRequest, "无效的请求") return @@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) { } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = "tts-1" + } else { + modelRequest.Model = "whisper-1" + } } } channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) diff --git a/router/relay-router.go b/router/relay-router.go index c97ea31..3916503 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay) + relayV1Router.POST("/audio/speech", controller.Relay) relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) From 773e48ed6f6623d6867076a982520356c7317310 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 16 Nov 2023 01:44:15 +0800 Subject: [PATCH 08/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=AF=AD=E9=9F=B3?= =?UTF-8?q?=E7=B3=BB=E5=88=97=E6=A8=A1=E5=9E=8B=E8=AE=A1=E8=B4=B9bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- common/model-ratio.go | 2 ++ controller/model.go | 36 ++++++++++++++++++++++++++++++++++++ controller/relay-audio.go | 10 ++++++++-- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 820f228..7d9d6ae 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -39,7 +39,9 @@ var ModelRatio = map[string]float64{ "code-davinci-edit-001": 10, "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "tts-1": 7.5, // 1k characters -> $0.015 + "tts-1-1106": 7.5, // 1k characters -> $0.015 "tts-1-hd": 15, // 1k characters -> $0.03 + "tts-1-hd-1106": 15, // 1k characters -> $0.03 "davinci": 10, "curie": 10, "babbage": 10, diff --git a/controller/model.go b/controller/model.go index 201d643..14bcc00 100644 --- a/controller/model.go +++ b/controller/model.go @@ -90,6 +90,42 @@ func init() { Root: "whisper-1", Parent: nil, }, + { + Id: "tts-1", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1", + Parent: nil, + }, + { + Id: "tts-1-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-1106", + Parent: nil, + }, + { + Id: "tts-1-hd", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd", + Parent: nil, + }, + { + Id: "tts-1-hd-1106", + Object: "model", + Created: 1677649963, + OwnedBy: "openai", + Permission: permission, + Root: "tts-1-hd-1106", + Parent: nil, + }, { Id: "gpt-3.5-turbo", Object: "model", diff --git a/controller/relay-audio.go b/controller/relay-audio.go index e959e73..2e9e49a 100644 --- a/controller/relay-audio.go +++ b/controller/relay-audio.go @@ -126,12 +126,18 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode defer func(ctx context.Context) { go func() { - var quota int + quota := 0 + var promptTokens = 0 if strings.HasPrefix(audioRequest.Model, "tts-1") { quota = countAudioToken(audioRequest.Input, audioRequest.Model) + promptTokens = quota } else { quota = countAudioToken(audioResponse.Text, audioRequest.Model) } + quota = int(float64(quota) * ratio) + if ratio != 0 && quota <= 0 { + quota = 1 + } quotaDelta := quota - preConsumedQuota err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) if err != nil { @@ -144,7 +150,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode if quota != 0 { tokenName := c.GetString("token_name") logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId) + model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") model.UpdateChannelUsedQuota(channelId, quota) From 9027ccf615f5490ed9fae3e9c468163b665b7b65 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 16 Nov 2023 01:45:08 +0800 Subject: [PATCH 09/15] =?UTF-8?q?=E5=AE=8C=E5=96=84dall-e-3=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/relay.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index 2ca2bc2..21cbfba 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -91,11 +91,13 @@ type TextRequest struct { } type ImageRequest struct { - Model string `json:"model"` - Quality string `json:"quality"` - Prompt string `json:"prompt"` - N int `json:"n"` - Size string `json:"size"` + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` } type AudioResponse struct { From 51be7f2882d41acf8cfc68a793d6c5ebe1fbf776 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 16:22:13 +0800 Subject: [PATCH 10/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=B6=85=E6=97=B6?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E7=A6=81=E7=94=A8bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index f3850f5..18f1e9b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -180,11 +180,12 @@ func testAllChannels(notify bool) error { err, openaiErr := testChannel(channel, *testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() + + ban := false if milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) - disableChannel(channel.Id, channel.Name, err.Error()) + ban = true } - ban := true // parse *int to bool if channel.AutoBan != nil && *channel.AutoBan == 0 { ban = false From a97bdebd0af78573d9abcb9f44b814db54b1dfaf Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 18:24:37 +0800 Subject: [PATCH 11/15] support gpt-4-1106-vision-preview --- controller/channel-test.go | 7 ++++++- controller/relay-aiproxy.go | 4 ++-- controller/relay-ali.go | 11 ++++++----- controller/relay-baidu.go | 7 ++++--- controller/relay-claude.go | 3 ++- controller/relay-openai.go | 2 +- controller/relay-palm.go | 5 +++-- controller/relay-tencent.go | 7 ++++--- controller/relay-text.go | 6 +++++- controller/relay-utils.go | 31 +++++++++++++++++++++++++------ controller/relay-xunfei.go | 7 ++++--- controller/relay-zhipu.go | 7 ++++--- controller/relay.go | 18 +++++++++++++++--- 13 files changed, 81 insertions(+), 34 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 18f1e9b..c0ac9a6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -86,9 +86,10 @@ func buildTestRequest() *ChatRequest { Model: "", // this will be set later MaxTokens: 1, } + content, _ := json.Marshal("hi") testMessage := Message{ Role: "user", - Content: "hi", + Content: content, } testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest @@ -186,6 +187,10 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) ban = true } + if openaiErr != nil { + err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) + ban = true + } // parse *int to bool if channel.AutoBan != nil && *channel.AutoBan == 0 { ban = false diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index d0159ce..7dbf679 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].Content + query = string(request.Messages[len(request.Messages)-1].Content) } return &AIProxyLibraryRequest{ Model: request.Model, @@ -69,7 +69,7 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { } func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { - content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents)) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 50dc743..6a79d2b 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { message := request.Messages[i] if message.Role == "system" { messages = append(messages, AliMessage{ - User: message.Content, + User: string(message.Content), Bot: "Okay", }) continue } else { if i == len(request.Messages)-1 { - prompt = message.Content + prompt = string(message.Content) break } messages = append(messages, AliMessage{ - User: message.Content, - Bot: request.Messages[i+1].Content, + User: string(message.Content), + Bot: string(request.Messages[i+1].Content), }) i++ } @@ -184,11 +184,12 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin } func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { + content, _ := json.Marshal(response.Output.Text) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Output.Text, + Content: content, }, FinishReason: response.Output.FinishReason, } diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ed08ac0..05bbad0 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { if message.Role == "system" { messages = append(messages, BaiduMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, BaiduMessage{ Role: "assistant", @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } else { messages = append(messages, BaiduMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -109,11 +109,12 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { + content, _ := json.Marshal(response.Result) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Result, + Content: content, }, FinishReason: "stop", } diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1f4a3e7..e131263 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -93,11 +93,12 @@ func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletion } func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { + content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: strings.TrimPrefix(claudeResponse.Completion, " "), + Content: content, Name: nil, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 6bdfbc0..9b08f85 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.Content, model) + completionTokens += countTokenText(string(choice.Message.Content), model) } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index a705b31..a7b0c1f 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { } for _, message := range textRequest.Messages { palmMessage := PaLMChatMessage{ - Content: message.Content, + Content: string(message.Content), } if message.Role == "user" { palmMessage.Author = "0" @@ -76,11 +76,12 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { + content, _ := json.Marshal(candidate.Content) choice := OpenAITextResponseChoice{ Index: i, Message: Message{ Role: "assistant", - Content: candidate.Content, + Content: content, }, FinishReason: "stop", } diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index 024468b..c96e6d4 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { if message.Role == "system" { messages = append(messages, TencentMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, TencentMessage{ Role: "assistant", @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { continue } messages = append(messages, TencentMessage{ - Content: message.Content, + Content: string(message.Content), Role: message.Role, }) } @@ -119,11 +119,12 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { Usage: response.Usage, } if len(response.Choices) > 0 { + content, _ := json.Marshal(response.Choices[0].Messages.Content) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Choices[0].Messages.Content, + Content: content, }, FinishReason: response.Choices[0].FinishReason, } diff --git a/controller/relay-text.go b/controller/relay-text.go index 2729650..a009267 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -199,9 +199,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } var promptTokens int var completionTokens int + var err error switch relayMode { case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model) + if err != nil { + return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + } case RelayModeCompletions: promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) case RelayModeModerations: diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 40aa547..177d853 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -63,7 +63,8 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func countTokenMessages(messages []Message, model string) int { +func countTokenMessages(messages []Message, model string) (int, error) { + //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -82,15 +83,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Content) tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) + var arrayContent []MediaMessage + if err := json.Unmarshal(message.Content, &arrayContent); err != nil { + + var stringContent string + if err := json.Unmarshal(message.Content, &stringContent); err != nil { + return 0, err + } else { + tokenNum += getTokenNum(tokenEncoder, stringContent) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } + } + } else { + for _, m := range arrayContent { + if m.Type == "image_url" { + //TODO: getImageToken + tokenNum += 1000 + } else { + tokenNum += getTokenNum(tokenEncoder, m.Text) + } + } } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum + return tokenNum, nil } func countTokenInput(input any, model string) int { diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 91fb604..33383d8 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma if message.Role == "system" { messages = append(messages, XunfeiMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, XunfeiMessage{ Role: "assistant", @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma } else { messages = append(messages, XunfeiMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -112,11 +112,12 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { }, } } + content, _ := json.Marshal(response.Payload.Choices.Text[0].Content) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Content: content, }, FinishReason: stopFinishReason, } diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 7a4a582..5ad4151 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { if message.Role == "system" { messages = append(messages, ZhipuMessage{ Role: "system", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, ZhipuMessage{ Role: "user", @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } else { messages = append(messages, ZhipuMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -144,11 +144,12 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { Usage: response.Data.Usage, } for i, choice := range response.Data.Choices { + content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) openaiChoice := OpenAITextResponseChoice{ Index: i, Message: Message{ Role: choice.Role, - Content: strings.Trim(choice.Content, "\""), + Content: content, }, FinishReason: "", } diff --git a/controller/relay.go b/controller/relay.go index 21cbfba..9e910fa 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,6 +1,7 @@ package controller import ( + "encoding/json" "fmt" "log" "net/http" @@ -12,9 +13,20 @@ import ( ) type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Name *string `json:"name,omitempty"` +} + +type MediaMessage struct { + Type string `json:"type"` + Text string `json:"text"` + ImageUrl MessageImageUrl `json:"image_url,omitempty"` +} + +type MessageImageUrl struct { + Url string `json:"url"` + Detail string `json:"detail"` } const ( From 2d1ca2d9be66f960fc2cb1fdbc35896f9a7218d1 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 20:32:11 +0800 Subject: [PATCH 12/15] fix image token calculate --- controller/relay-utils.go | 74 +++++++++++++++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 ++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 177d853..1873cab 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -2,10 +2,18 @@ package controller import ( "encoding/json" + "errors" "fmt" + "github.com/chai2010/webp" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" + "log" + "math" "net/http" "one-api/common" "strconv" @@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } +func getImageToken(imageUrl MessageImageUrl) (int, error) { + if imageUrl.Detail == "low" { + return 85, nil + } + + response, err := http.Get(imageUrl.Url) + if err != nil { + fmt.Println("Error: Failed to get the URL") + return 0, err + } + + defer response.Body.Close() + + // 限制读取的字节数,防止下载整个图片 + limitReader := io.LimitReader(response.Body, 8192) + + // 读取图片的头部信息来获取图片尺寸 + config, _, err := image.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) + config, err = webp.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) + } + } + if config.Width == 0 || config.Height == 0 { + return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error())) + } + if config.Width < 512 && config.Height < 512 { + if imageUrl.Detail == "auto" || imageUrl.Detail == "" { + return 85, nil + } + } + + shortSide := config.Width + otherSide := config.Height + log.Printf("width: %d, height: %d", config.Width, config.Height) + // 缩放倍数 + scale := 1.0 + if config.Height < shortSide { + shortSide = config.Height + otherSide = config.Width + } + + // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 + if shortSide > 768 { + scale = float64(shortSide) / 768 + shortSide = 768 + } + // 将另一边按照相同的比例缩小,向上取整 + otherSide = int(math.Ceil(float64(otherSide) / scale)) + log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) + // 计算图片的token数量(边的长度除以512,向上取整) + tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) + log.Printf("tiles: %d", tiles) + return tiles*170 + 85, nil +} + func countTokenMessages(messages []Message, model string) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) @@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) { } else { for _, m := range arrayContent { if m.Type == "image_url" { - //TODO: getImageToken - tokenNum += 1000 + imageTokenNum, err := getImageToken(m.ImageUrl) + if err != nil { + return 0, err + } + tokenNum += imageTokenNum + log.Printf("image token num: %d", imageTokenNum) } else { tokenNum += getTokenNum(tokenEncoder, m.Text) } diff --git a/go.mod b/go.mod index a82121b..3a75341 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/chai2010/webp v1.1.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 diff --git a/go.sum b/go.sum index 2d64620..6e7f963 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk= +github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= From 7e0d2606c305f4ac769d17eca1f4fb0be05039d8 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 18:24:37 +0800 Subject: [PATCH 13/15] support gpt-4-1106-vision-preview --- controller/channel-test.go | 7 ++++++- controller/relay-aiproxy.go | 4 ++-- controller/relay-ali.go | 11 ++++++----- controller/relay-baidu.go | 7 ++++--- controller/relay-claude.go | 3 ++- controller/relay-openai.go | 2 +- controller/relay-palm.go | 5 +++-- controller/relay-tencent.go | 7 ++++--- controller/relay-text.go | 6 +++++- controller/relay-utils.go | 31 +++++++++++++++++++++++++------ controller/relay-xunfei.go | 7 ++++--- controller/relay-zhipu.go | 7 ++++--- controller/relay.go | 18 +++++++++++++++--- 13 files changed, 81 insertions(+), 34 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 18f1e9b..c0ac9a6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -86,9 +86,10 @@ func buildTestRequest() *ChatRequest { Model: "", // this will be set later MaxTokens: 1, } + content, _ := json.Marshal("hi") testMessage := Message{ Role: "user", - Content: "hi", + Content: content, } testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest @@ -186,6 +187,10 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) ban = true } + if openaiErr != nil { + err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) + ban = true + } // parse *int to bool if channel.AutoBan != nil && *channel.AutoBan == 0 { ban = false diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go index d0159ce..7dbf679 100644 --- a/controller/relay-aiproxy.go +++ b/controller/relay-aiproxy.go @@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct { func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest { query := "" if len(request.Messages) != 0 { - query = request.Messages[len(request.Messages)-1].Content + query = string(request.Messages[len(request.Messages)-1].Content) } return &AIProxyLibraryRequest{ Model: request.Model, @@ -69,7 +69,7 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string { } func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse { - content := response.Answer + aiProxyDocuments2Markdown(response.Documents) + content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents)) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ diff --git a/controller/relay-ali.go b/controller/relay-ali.go index 50dc743..6a79d2b 100644 --- a/controller/relay-ali.go +++ b/controller/relay-ali.go @@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest { message := request.Messages[i] if message.Role == "system" { messages = append(messages, AliMessage{ - User: message.Content, + User: string(message.Content), Bot: "Okay", }) continue } else { if i == len(request.Messages)-1 { - prompt = message.Content + prompt = string(message.Content) break } messages = append(messages, AliMessage{ - User: message.Content, - Bot: request.Messages[i+1].Content, + User: string(message.Content), + Bot: string(request.Messages[i+1].Content), }) i++ } @@ -184,11 +184,12 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin } func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse { + content, _ := json.Marshal(response.Output.Text) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Output.Text, + Content: content, }, FinishReason: response.Output.FinishReason, } diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go index ed08ac0..05bbad0 100644 --- a/controller/relay-baidu.go +++ b/controller/relay-baidu.go @@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { if message.Role == "system" { messages = append(messages, BaiduMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, BaiduMessage{ Role: "assistant", @@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } else { messages = append(messages, BaiduMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -109,11 +109,12 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest { } func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse { + content, _ := json.Marshal(response.Result) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Result, + Content: content, }, FinishReason: "stop", } diff --git a/controller/relay-claude.go b/controller/relay-claude.go index 1f4a3e7..e131263 100644 --- a/controller/relay-claude.go +++ b/controller/relay-claude.go @@ -93,11 +93,12 @@ func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletion } func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse { + content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: strings.TrimPrefix(claudeResponse.Completion, " "), + Content: content, Name: nil, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), diff --git a/controller/relay-openai.go b/controller/relay-openai.go index 6bdfbc0..9b08f85 100644 --- a/controller/relay-openai.go +++ b/controller/relay-openai.go @@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp if textResponse.Usage.TotalTokens == 0 { completionTokens := 0 for _, choice := range textResponse.Choices { - completionTokens += countTokenText(choice.Message.Content, model) + completionTokens += countTokenText(string(choice.Message.Content), model) } textResponse.Usage = Usage{ PromptTokens: promptTokens, diff --git a/controller/relay-palm.go b/controller/relay-palm.go index a705b31..a7b0c1f 100644 --- a/controller/relay-palm.go +++ b/controller/relay-palm.go @@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest { } for _, message := range textRequest.Messages { palmMessage := PaLMChatMessage{ - Content: message.Content, + Content: string(message.Content), } if message.Role == "user" { palmMessage.Author = "0" @@ -76,11 +76,12 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse { Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)), } for i, candidate := range response.Candidates { + content, _ := json.Marshal(candidate.Content) choice := OpenAITextResponseChoice{ Index: i, Message: Message{ Role: "assistant", - Content: candidate.Content, + Content: content, }, FinishReason: "stop", } diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go index 024468b..c96e6d4 100644 --- a/controller/relay-tencent.go +++ b/controller/relay-tencent.go @@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { if message.Role == "system" { messages = append(messages, TencentMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, TencentMessage{ Role: "assistant", @@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { continue } messages = append(messages, TencentMessage{ - Content: message.Content, + Content: string(message.Content), Role: message.Role, }) } @@ -119,11 +119,12 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { Usage: response.Usage, } if len(response.Choices) > 0 { + content, _ := json.Marshal(response.Choices[0].Messages.Content) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Choices[0].Messages.Content, + Content: content, }, FinishReason: response.Choices[0].FinishReason, } diff --git a/controller/relay-text.go b/controller/relay-text.go index 2729650..a009267 100644 --- a/controller/relay-text.go +++ b/controller/relay-text.go @@ -199,9 +199,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { } var promptTokens int var completionTokens int + var err error switch relayMode { case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) + promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model) + if err != nil { + return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + } case RelayModeCompletions: promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) case RelayModeModerations: diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 40aa547..177d853 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -63,7 +63,8 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } -func countTokenMessages(messages []Message, model string) int { +func countTokenMessages(messages []Message, model string) (int, error) { + //recover when panic tokenEncoder := getTokenEncoder(model) // Reference: // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb @@ -82,15 +83,33 @@ func countTokenMessages(messages []Message, model string) int { tokenNum := 0 for _, message := range messages { tokenNum += tokensPerMessage - tokenNum += getTokenNum(tokenEncoder, message.Content) tokenNum += getTokenNum(tokenEncoder, message.Role) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) + var arrayContent []MediaMessage + if err := json.Unmarshal(message.Content, &arrayContent); err != nil { + + var stringContent string + if err := json.Unmarshal(message.Content, &stringContent); err != nil { + return 0, err + } else { + tokenNum += getTokenNum(tokenEncoder, stringContent) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } + } + } else { + for _, m := range arrayContent { + if m.Type == "image_url" { + //TODO: getImageToken + tokenNum += 1000 + } else { + tokenNum += getTokenNum(tokenEncoder, m.Text) + } + } } } tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|> - return tokenNum + return tokenNum, nil } func countTokenInput(input any, model string) int { diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go index 91fb604..33383d8 100644 --- a/controller/relay-xunfei.go +++ b/controller/relay-xunfei.go @@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma if message.Role == "system" { messages = append(messages, XunfeiMessage{ Role: "user", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, XunfeiMessage{ Role: "assistant", @@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma } else { messages = append(messages, XunfeiMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -112,11 +112,12 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse { }, } } + content, _ := json.Marshal(response.Payload.Choices.Text[0].Content) choice := OpenAITextResponseChoice{ Index: 0, Message: Message{ Role: "assistant", - Content: response.Payload.Choices.Text[0].Content, + Content: content, }, FinishReason: stopFinishReason, } diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go index 7a4a582..5ad4151 100644 --- a/controller/relay-zhipu.go +++ b/controller/relay-zhipu.go @@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { if message.Role == "system" { messages = append(messages, ZhipuMessage{ Role: "system", - Content: message.Content, + Content: string(message.Content), }) messages = append(messages, ZhipuMessage{ Role: "user", @@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest { } else { messages = append(messages, ZhipuMessage{ Role: message.Role, - Content: message.Content, + Content: string(message.Content), }) } } @@ -144,11 +144,12 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse { Usage: response.Data.Usage, } for i, choice := range response.Data.Choices { + content, _ := json.Marshal(strings.Trim(choice.Content, "\"")) openaiChoice := OpenAITextResponseChoice{ Index: i, Message: Message{ Role: choice.Role, - Content: strings.Trim(choice.Content, "\""), + Content: content, }, FinishReason: "", } diff --git a/controller/relay.go b/controller/relay.go index 06ef341..714910c 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -1,6 +1,7 @@ package controller import ( + "encoding/json" "fmt" "log" "net/http" @@ -12,9 +13,20 @@ import ( ) type Message struct { - Role string `json:"role"` - Content string `json:"content"` - Name *string `json:"name,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Name *string `json:"name,omitempty"` +} + +type MediaMessage struct { + Type string `json:"type"` + Text string `json:"text"` + ImageUrl MessageImageUrl `json:"image_url,omitempty"` +} + +type MessageImageUrl struct { + Url string `json:"url"` + Detail string `json:"detail"` } const ( From e5c2524f152dcc6a39bc2f6369f7df2110475df2 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 20:32:11 +0800 Subject: [PATCH 14/15] fix image token calculate --- controller/relay-utils.go | 74 +++++++++++++++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 ++ 3 files changed, 75 insertions(+), 2 deletions(-) diff --git a/controller/relay-utils.go b/controller/relay-utils.go index 177d853..1873cab 100644 --- a/controller/relay-utils.go +++ b/controller/relay-utils.go @@ -2,10 +2,18 @@ package controller import ( "encoding/json" + "errors" "fmt" + "github.com/chai2010/webp" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" + "image" + _ "image/gif" + _ "image/jpeg" + _ "image/png" "io" + "log" + "math" "net/http" "one-api/common" "strconv" @@ -63,6 +71,64 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } +func getImageToken(imageUrl MessageImageUrl) (int, error) { + if imageUrl.Detail == "low" { + return 85, nil + } + + response, err := http.Get(imageUrl.Url) + if err != nil { + fmt.Println("Error: Failed to get the URL") + return 0, err + } + + defer response.Body.Close() + + // 限制读取的字节数,防止下载整个图片 + limitReader := io.LimitReader(response.Body, 8192) + + // 读取图片的头部信息来获取图片尺寸 + config, _, err := image.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(gif, jpg, png): %s", err.Error())) + config, err = webp.DecodeConfig(limitReader) + if err != nil { + common.SysLog(fmt.Sprintf("fail to decode image config(webp): %s", err.Error())) + } + } + if config.Width == 0 || config.Height == 0 { + return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", err.Error())) + } + if config.Width < 512 && config.Height < 512 { + if imageUrl.Detail == "auto" || imageUrl.Detail == "" { + return 85, nil + } + } + + shortSide := config.Width + otherSide := config.Height + log.Printf("width: %d, height: %d", config.Width, config.Height) + // 缩放倍数 + scale := 1.0 + if config.Height < shortSide { + shortSide = config.Height + otherSide = config.Width + } + + // 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768 + if shortSide > 768 { + scale = float64(shortSide) / 768 + shortSide = 768 + } + // 将另一边按照相同的比例缩小,向上取整 + otherSide = int(math.Ceil(float64(otherSide) / scale)) + log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale) + // 计算图片的token数量(边的长度除以512,向上取整) + tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512) + log.Printf("tiles: %d", tiles) + return tiles*170 + 85, nil +} + func countTokenMessages(messages []Message, model string) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) @@ -100,8 +166,12 @@ func countTokenMessages(messages []Message, model string) (int, error) { } else { for _, m := range arrayContent { if m.Type == "image_url" { - //TODO: getImageToken - tokenNum += 1000 + imageTokenNum, err := getImageToken(m.ImageUrl) + if err != nil { + return 0, err + } + tokenNum += imageTokenNum + log.Printf("image token num: %d", imageTokenNum) } else { tokenNum += getTokenNum(tokenEncoder, m.Text) } diff --git a/go.mod b/go.mod index a82121b..3a75341 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ module one-api go 1.18 require ( + github.com/chai2010/webp v1.1.1 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 diff --git a/go.sum b/go.sum index 2d64620..6e7f963 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chai2010/webp v1.1.1 h1:jTRmEccAJ4MGrhFOrPMpNGIJ/eybIgwKpcACsrTEapk= +github.com/chai2010/webp v1.1.1/go.mod h1:0XVwvZWdjjdxpUEIf7b9g9VkHFnInUSYujwqTLEuldU= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= From 3f085b612694f943a380f0ab6d508e5b8f07d932 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Fri, 17 Nov 2023 20:47:51 +0800 Subject: [PATCH 15/15] update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 89dfba3..edc34d9 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用情况,方便二次分销 5. 渠道显示已使用额度,支持指定组织访问 6. 分页支持选择每页显示数量 +7. 支持gpt-4-1106-vision-preview,dall-e-3,tts-1 ## 界面截图 ![image](https://github.com/Calcium-Ion/new-api/assets/61247483/3ca0b282-00ff-4c96-bf9d-e29ef615c605)