Merge branch 'main' into latest

This commit is contained in:
CaIon
2023-11-15 21:07:37 +08:00
12 changed files with 118 additions and 66 deletions

View File

@@ -1,52 +1,30 @@
name: Publish Docker image (amd64) name: Docker Image CI
on: on:
push: push:
tags: branches: [ "main" ]
- '*' pull_request:
workflow_dispatch: branches: [ "main" ]
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
push_to_registries:
name: Push Docker image to multiple registries build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps: steps:
- name: Check out the repo - uses: actions/checkout@v3
uses: actions/checkout@v3 - uses: docker/login-action@v3.0.0
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with: with:
username: ${{ secrets.DOCKERHUB_USERNAME }} username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }} password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker - name: Extract metadata (tags, labels) for Docker
id: meta id: meta
uses: docker/metadata-action@v4 uses: docker/metadata-action@v3
with: with:
images: | images: calciumion/neko-api
justsong/one-api - name: Build the Docker image
ghcr.io/${{ github.repository }} uses: docker/build-push-action@v5.0.0
- name: Build and push Docker images
uses: docker/build-push-action@v3
with: with:
context: . context: .
push: true push: true

View File

@@ -37,7 +37,9 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10, "text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"tts-1": 7.5, // 1k characters -> $0.015
"tts-1-hd": 15, // 1k characters -> $0.03
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,

View File

@@ -207,3 +207,12 @@ func String2Int(str string) int {
} }
return num return num
} }
func StringsContains(strs []string, str string) bool {
for _, s := range strs {
if s == str {
return true
}
}
return false
}

View File

@@ -11,10 +11,19 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/model" "one-api/model"
"strings"
) )
var availableVoices = []string{
"alloy",
"echo",
"fable",
"onyx",
"nova",
"shimmer",
}
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
audioModel := "whisper-1"
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel") channelType := c.GetInt("channel")
@@ -22,8 +31,28 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
userId := c.GetInt("id") userId := c.GetInt("id")
group := c.GetString("group") group := c.GetString("group")
var audioRequest AudioRequest
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
// request validation
if audioRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
if strings.HasPrefix(audioRequest.Model, "tts-1") {
if audioRequest.Voice == "" {
return errorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
}
if !common.StringsContains(availableVoices, audioRequest.Voice) {
return errorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
}
}
preConsumedTokens := common.PreConsumedQuota preConsumedTokens := common.PreConsumedQuota
modelRatio := common.GetModelRatio(audioModel) modelRatio := common.GetModelRatio(audioRequest.Model)
groupRatio := common.GetGroupRatio(group) groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio) preConsumedQuota := int(float64(preConsumedTokens) * ratio)
@@ -58,8 +87,8 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
} }
if modelMap[audioModel] != "" { if modelMap[audioRequest.Model] != "" {
audioModel = modelMap[audioModel] audioRequest.Model = modelMap[audioRequest.Model]
} }
} }
@@ -97,9 +126,14 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
defer func(ctx context.Context) { defer func(ctx context.Context) {
go func() { go func() {
quota := countTokenText(audioResponse.Text, audioModel) var quota int
if strings.HasPrefix(audioRequest.Model, "tts-1") {
quota = countAudioToken(audioRequest.Input, audioRequest.Model)
} else {
quota = countAudioToken(audioResponse.Text, audioRequest.Model)
}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota) err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }
@@ -110,7 +144,7 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 { if quota != 0 {
tokenName := c.GetString("token_name") tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent, tokenId) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota) model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)
@@ -127,9 +161,13 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if err != nil { if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
} }
err = json.Unmarshal(responseBody, &audioResponse) if strings.HasPrefix(audioRequest.Model, "tts-1") {
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } else {
err = json.Unmarshal(responseBody, &audioResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
} }
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

View File

@@ -147,7 +147,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
var textResponse ImageResponse var textResponse ImageResponse
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0) err := model.PostConsumeTokenQuota(tokenId, userId, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -359,7 +359,7 @@ func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
defer func(ctx context.Context) { defer func(ctx context.Context) {
if consumeQuota { if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0) err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true)
if err != nil { if err != nil {
common.SysError("error consuming token remain quota: " + err.Error()) common.SysError("error consuming token remain quota: " + err.Error())
} }

View File

@@ -400,7 +400,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if preConsumedQuota != 0 { if preConsumedQuota != 0 {
go func(ctx context.Context) { go func(ctx context.Context) {
// return pre-consumed quota // return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0) err := model.PostConsumeTokenQuota(tokenId, userQuota, -preConsumedQuota, 0, false)
if err != nil { if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
} }
@@ -434,7 +434,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
quota = 0 quota = 0
} }
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota) err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error()) common.LogError(ctx, "error consuming token remain quota: "+err.Error())
} }

View File

@@ -10,6 +10,7 @@ import (
"one-api/common" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"unicode/utf8"
) )
var stopFinishReason = "stop" var stopFinishReason = "stop"
@@ -106,6 +107,14 @@ func countTokenInput(input any, model string) int {
return 0 return 0
} }
func countAudioToken(text string, model string) int {
if strings.HasPrefix(model, "tts") {
return utf8.RuneCountInString(text)
} else {
return countTokenText(text, model)
}
}
func countTokenText(text string, model string) int { func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model) tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text) return getTokenNum(tokenEncoder, text)

View File

@@ -70,6 +70,12 @@ func (r GeneralOpenAIRequest) ParseInput() []string {
return input return input
} }
type AudioRequest struct {
Model string `json:"model"`
Voice string `json:"voice"`
Input string `json:"input"`
}
type ChatRequest struct { type ChatRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`

View File

@@ -46,9 +46,8 @@ func Distribute() func(c *gin.Context) {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "midjourney" modelRequest.Model = "midjourney"
} }
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
err = common.UnmarshalBodyReusable(c, &modelRequest)
} }
err = common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return return
@@ -70,7 +69,11 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "whisper-1" if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
modelRequest.Model = "tts-1"
} else {
modelRequest.Model = "whisper-1"
}
} }
} }
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"strconv"
"strings" "strings"
) )
@@ -194,22 +195,31 @@ func PreConsumeTokenQuota(tokenId int, quota int) (userQuota int, err error) {
return 0, err return 0, err
} }
if userQuota < quota { if userQuota < quota {
return userQuota, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) return 0, errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota))
} }
if !token.UnlimitedQuota { if !token.UnlimitedQuota {
err = DecreaseTokenQuota(tokenId, quota) err = DecreaseTokenQuota(tokenId, quota)
if err != nil { if err != nil {
return userQuota, err return 0, err
} }
} }
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
return userQuota, err return userQuota - quota, err
} }
func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int) (err error) { func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuota int, sendEmail bool) (err error) {
token, err := GetTokenById(tokenId) token, err := GetTokenById(tokenId)
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
}
if sendEmail {
quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-(quota+preConsumedQuota) < common.QuotaRemindThreshold
noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0 noMoreQuota := userQuota-(quota+preConsumedQuota) <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
@@ -229,16 +239,12 @@ func PostConsumeTokenQuota(tokenId int, userQuota int, quota int, preConsumedQuo
if err != nil { if err != nil {
common.SysError("failed to send email" + err.Error()) common.SysError("failed to send email" + err.Error())
} }
common.SysLog("user quota is low, consumed quota: " + strconv.Itoa(quota) + ", user quota: " + strconv.Itoa(userQuota))
} }
}() }()
} }
err = DecreaseUserQuota(token.UserId, quota)
} else {
err = IncreaseUserQuota(token.UserId, -quota)
}
if err != nil {
return err
} }
if !token.UnlimitedQuota { if !token.UnlimitedQuota {
if quota > 0 { if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota) err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.POST("/audio/transcriptions", controller.Relay) relayV1Router.POST("/audio/transcriptions", controller.Relay)
relayV1Router.POST("/audio/translations", controller.Relay) relayV1Router.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented) relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)