diff --git a/.github/workflows/linux-release.yml b/.github/workflows/linux-release.yml
new file mode 100644
index 00000000..e81ab09f
--- /dev/null
+++ b/.github/workflows/linux-release.yml
@@ -0,0 +1,60 @@
+name: Linux Release
+permissions:
+  contents: write
+
+on:
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+jobs:
+  release:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: actions/setup-node@v3
+        with:
+          node-version: 16
+      - name: Build Frontend
+        env:
+          CI: ""
+        run: |
+          cd web
+          git describe --tags > VERSION
+          chmod u+x ./build.sh
+          REACT_APP_VERSION=$(git describe --tags) ./build.sh
+          cd ..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend (amd64)
+        run: |
+          go mod download
+          go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api
+
+      - name: Build Backend (arm64)
+        run: |
+          sudo apt-get update
+          sudo apt-get install gcc-aarch64-linux-gnu
+          CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64
+
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: |
+            one-api
+            one-api-arm64
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/.github/workflows/macos-release.yml b/.github/workflows/macos-release.yml
new file mode 100644
index 00000000..13415276
--- /dev/null
+++ b/.github/workflows/macos-release.yml
@@ -0,0 +1,51 @@
+name: macOS Release
+permissions:
+  contents: write
+
+on:
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+jobs:
+  release:
+    runs-on: macos-latest
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: actions/setup-node@v3
+        with:
+          node-version: 16
+      - name: Build Frontend
+        env:
+          CI: ""
+        run: |
+          cd web
+          git describe --tags > VERSION
+          chmod u+x ./build.sh
+          REACT_APP_VERSION=$(git describe --tags) ./build.sh
+          cd ..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend
+        run: |
+          go mod download
+          go build -ldflags "-X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api-macos
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: one-api-macos
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/.github/workflows/windows-release.yml b/.github/workflows/windows-release.yml
new file mode 100644
index 00000000..8b1160b4
--- /dev/null
+++ b/.github/workflows/windows-release.yml
@@ -0,0 +1,53 @@
+name: Windows Release
+permissions:
+  contents: write
+
+on:
+  push:
+    tags:
+      - '*'
+      - '!*-alpha*'
+  workflow_dispatch:
+    inputs:
+      name:
+        description: 'reason'
+        required: false
+jobs:
+  release:
+    runs-on: windows-latest
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 0
+      - uses: actions/setup-node@v3
+        with:
+          node-version: 16
+      - name: Build Frontend
+        env:
+          CI: ""
+        run: |
+          cd web/default
+          npm install
+          REACT_APP_VERSION=$(git describe --tags) npm run build
+          cd ../..
+      - name: Set up Go
+        uses: actions/setup-go@v3
+        with:
+          go-version: '>=1.18.0'
+      - name: Build Backend
+        run: |
+          go mod download
+          go build -ldflags "-s -w -X 'github.com/songquanpeng/one-api/common.Version=$(git describe --tags)'" -o one-api.exe
+      - name: Release
+        uses: softprops/action-gh-release@v1
+        if: startsWith(github.ref, 'refs/tags/')
+        with:
+          files: one-api.exe
+          draft: true
+          generate_release_notes: true
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
\ No newline at end of file
diff --git a/common/constants.go b/common/constants.go
index ccaa3560..ac901139 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -64,6 +64,9 @@ const (
 	ChannelTypeTencent  = 23
 	ChannelTypeGemini   = 24
 	ChannelTypeMoonshot = 25
+	ChannelTypeBaichuan = 26
+	ChannelTypeMinimax  = 27
+	ChannelTypeMistral  = 28
 )
 
 var ChannelBaseURLs = []string{
@@ -93,6 +96,9 @@ var ChannelBaseURLs = []string{
 	"https://hunyuan.cloud.tencent.com",         // 23
 	"https://generativelanguage.googleapis.com", // 24
 	"https://api.moonshot.cn",                   // 25
+	"https://api.baichuan-ai.com",               // 26
+	"https://api.minimax.chat",                  // 27
+	"https://api.mistral.ai",                    // 28
 }
 
 const (
diff --git a/common/logger/logger.go b/common/logger/logger.go
index f970ee61..8232b2fc 100644
--- a/common/logger/logger.go
+++ b/common/logger/logger.go
@@ -13,6 +13,7 @@ import (
 )
 
 const (
+	loggerDEBUG = "DEBUG"
 	loggerINFO  = "INFO"
 	loggerWarn  = "WARN"
 	loggerError = "ERR"
@@ -55,6 +56,10 @@ func SysError(s string) {
 	_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
 }
 
+func Debug(ctx context.Context, msg string) {
+	logHelper(ctx, loggerDEBUG, msg)
+}
+
 func Info(ctx context.Context, msg string) {
 	logHelper(ctx, loggerINFO, msg)
 }
@@ -67,6 +72,10 @@ func Error(ctx context.Context, msg string) {
 	logHelper(ctx, loggerError, msg)
 }
 
+func Debugf(ctx context.Context, format string, a ...any) {
+	Debug(ctx, fmt.Sprintf(format, a...))
+}
+
 func Infof(ctx context.Context, format string, a ...any) {
 	Info(ctx, fmt.Sprintf(format, a...))
 }
"512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. -} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} - const ( USD2RMB = 7 USD = 500 // $0.002 = 1 -> $1 = 500 @@ -40,7 +17,6 @@ const ( // https://platform.openai.com/docs/models/model-endpoint-compatibility // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf // https://openai.com/pricing -// TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens var ModelRatio = map[string]float64{ @@ -94,14 +70,18 @@ var ModelRatio = map[string]float64{ "claude-2.0": 5.51, // $11.02 / 1M tokens "claude-2.1": 5.51, // $11.02 / 1M tokens // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8k": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens + "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens + "ERNIE-Bot-8k": 0.024 * RMB, + "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens + "PaLM-2": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + // https://open.bigmodel.cn/pricing + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens @@ -127,6 +107,37 @@ var ModelRatio = map[string]float64{ "moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-32k": 0.024 * RMB, "moonshot-v1-128k": 0.06 * RMB, + // https://platform.baichuan-ai.com/price + "Baichuan2-Turbo": 0.008 * RMB, + "Baichuan2-Turbo-192k": 0.016 * RMB, + "Baichuan2-53B": 0.02 * RMB, + // https://api.minimax.chat/document/price + "abab6-chat": 0.1 * RMB, + "abab5.5-chat": 0.015 * RMB, + "abab5.5s-chat": 0.005 * RMB, + // https://docs.mistral.ai/platform/pricing/ + "open-mistral-7b": 0.25 / 1000 * USD, + "open-mixtral-8x7b": 0.7 / 1000 * USD, + "mistral-small-latest": 2.0 / 1000 * USD, + "mistral-medium-latest": 2.7 / 1000 * USD, + "mistral-large-latest": 8.0 / 1000 * USD, + "mistral-embed": 0.1 / 1000 * USD, +} + +var CompletionRatio = map[string]float64{} + +var DefaultModelRatio map[string]float64 +var DefaultCompletionRatio map[string]float64 + +func init() { + DefaultModelRatio = make(map[string]float64) + for k, v := range ModelRatio { + DefaultModelRatio[k] = v + } + DefaultCompletionRatio = make(map[string]float64) + for k, v := range CompletionRatio { + DefaultCompletionRatio[k] = v + } } func ModelRatio2JSONString() string { @@ -147,6 +158,9 @@ func GetModelRatio(name string) float64 { name = strings.TrimSuffix(name, "-internet") } ratio, ok := ModelRatio[name] + if !ok { + ratio, ok = DefaultModelRatio[name] + } if !ok { logger.SysError("model ratio not found: " + name) return 30 @@ -154,8 +168,6 @@ func GetModelRatio(name string) float64 { return ratio } -var CompletionRatio = map[string]float64{} - func 
diff --git a/common/random.go b/common/random.go
new file mode 100644
index 00000000..44bd2856
--- /dev/null
+++ b/common/random.go
@@ -0,0 +1,8 @@
+package common
+
+import "math/rand"
+
+// RandRange returns a random number between min and max (max is not included)
+func RandRange(min, max int) int {
+	return min + rand.Intn(max-min)
+}
diff --git a/controller/channel-test.go b/controller/channel-test.go
index b498f4f1..7007e205 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -8,6 +8,7 @@ import (
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/config"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/middleware"
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
@@ -18,6 +19,7 @@ import (
 	"net/http/httptest"
 	"net/url"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -51,6 +53,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	c.Request.Header.Set("Content-Type", "application/json")
 	c.Set("channel", channel.Type)
 	c.Set("base_url", channel.GetBaseURL())
+	middleware.SetupContextForSelectedChannel(c, channel, "")
 	meta := util.GetRelayMeta(c)
 	apiType := constant.ChannelType2APIType(channel.Type)
 	adaptor := helper.GetAdaptor(apiType)
@@ -59,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
 	}
 	adaptor.Init(meta)
 	modelName := adaptor.GetModelList()[0]
+	if !strings.Contains(channel.Models, modelName) { // the adaptor's default test model may not be enabled on this channel
+		modelNames := strings.Split(channel.Models, ",")
+		if len(modelNames) > 0 {
+			modelName = modelNames[0]
+		}
+	}
 	request := buildTestRequest()
 	request.Model = modelName
 	meta.OriginModelName, meta.ActualModelName = modelName, modelName
diff --git a/controller/model.go b/controller/model.go
index 09d205c8..ae10ff32 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -4,6 +4,9 @@ import (
 	"fmt"
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/relay/channel/ai360"
+	"github.com/songquanpeng/one-api/relay/channel/baichuan"
+	"github.com/songquanpeng/one-api/relay/channel/minimax"
+	"github.com/songquanpeng/one-api/relay/channel/mistral"
 	"github.com/songquanpeng/one-api/relay/channel/moonshot"
 	"github.com/songquanpeng/one-api/relay/constant"
 	"github.com/songquanpeng/one-api/relay/helper"
@@ -102,6 +105,39 @@ func init() {
 			Parent:     nil,
 		})
 	}
+	for _, modelName := range baichuan.ModelList {
+		openAIModels = append(openAIModels, OpenAIModels{
+			Id:         modelName,
+			Object:     "model",
+			Created:    1626777600,
+			OwnedBy:    "baichuan",
+			Permission: permission,
+			Root:       modelName,
+			Parent:     nil,
+		})
+	}
+	for _, modelName := range minimax.ModelList {
+		openAIModels = append(openAIModels, OpenAIModels{
+			Id:         modelName,
+			Object:     "model",
+			Created:    1626777600,
+			OwnedBy:    "minimax",
+			Permission: permission,
+			Root:       modelName,
+			Parent:     nil,
+		})
+	}
+	for _, modelName := range mistral.ModelList {
+		openAIModels = append(openAIModels, OpenAIModels{
+			Id:         modelName,
+			Object:     "model",
+			Created:    1626777600,
+			OwnedBy:    "mistralai",
+			Permission: permission,
+			Root:       modelName,
+			Parent:     nil,
+		})
+	}
 	openAIModelsMap = make(map[string]OpenAIModels)
 	for _, model := range openAIModels {
 		openAIModelsMap[model.Id] = model
diff --git a/controller/relay.go b/controller/relay.go
index 6c03fefe..7cbbc970 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -41,6 +41,10 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
 func Relay(c *gin.Context) {
 	ctx := c.Request.Context()
 	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	if config.DebugEnabled {
+		requestBody, _ := common.GetRequestBody(c)
+		logger.Debugf(ctx, "request body: %s", string(requestBody))
+	}
 	bizErr := relay(c, relayMode)
 	if bizErr == nil {
 		return
@@ -58,7 +62,7 @@ func Relay(c *gin.Context) {
 		retryTimes = 0
 	}
 	for i := retryTimes; i > 0; i-- {
-		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
+		channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
 		if err != nil {
 			logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
 			break
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 4f77e786..33926a54 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -68,7 +68,7 @@ func Distribute() func(c *gin.Context) {
 			}
 		}
 		requestModel = modelRequest.Model
-		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
+		channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
 		if err != nil {
 			message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 			if channel != nil {
diff --git a/model/cache.go b/model/cache.go
index 04a60348..3c3575b8 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -191,7 +191,7 @@ func SyncChannelCache(frequency int) {
 	}
 }
 
-func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
+func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
 	if !config.MemoryCacheEnabled {
 		return GetRandomSatisfiedChannel(group, model)
 	}
@@ -213,5 +213,10 @@
 		}
 	}
 	idx := rand.Intn(endIdx)
+	if ignoreFirstPriority {
+		if endIdx < len(channels) { // i.e. there is more than one priority group
+			idx = common.RandRange(endIdx, len(channels))
+		}
+	}
 	return channels[idx], nil
 }
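Taken together, the retry pieces work like this: the cached channel list is sorted with the highest-priority group first and `endIdx` marks the end of that group, so `ignoreFirstPriority` redirects the draw to `[endIdx, len(channels))`. Note that `controller/relay.go` passes `i != retryTimes`, so the first retry still samples the top group and only later attempts fall through. A self-contained sketch with invented channel data:

```go
package main

import (
	"fmt"
	"math/rand"
)

type channel struct {
	name     string
	priority int64
}

// Sorted the way the cache keeps them: highest priority first.
var channels = []channel{
	{"primary-a", 100}, {"primary-b", 100}, // first priority group
	{"backup", 50}, {"last-resort", 10}, // lower priorities
}

func pick(ignoreFirstPriority bool) channel {
	endIdx := 1
	for endIdx < len(channels) && channels[endIdx].priority == channels[0].priority {
		endIdx++ // endIdx now bounds the first priority group
	}
	idx := rand.Intn(endIdx)
	if ignoreFirstPriority && endIdx < len(channels) {
		idx = endIdx + rand.Intn(len(channels)-endIdx) // RandRange(endIdx, len(channels))
	}
	return channels[idx]
}

func main() {
	fmt.Println(pick(false)) // primary-a or primary-b
	fmt.Println(pick(true))  // backup or last-resort
}
```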
diff --git a/relay/channel/baichuan/constants.go b/relay/channel/baichuan/constants.go
new file mode 100644
index 00000000..cb20a1ff
--- /dev/null
+++ b/relay/channel/baichuan/constants.go
@@ -0,0 +1,7 @@
+package baichuan
+
+var ModelList = []string{
+	"Baichuan2-Turbo",
+	"Baichuan2-Turbo-192k",
+	"Baichuan-Text-Embedding",
+}
diff --git a/relay/channel/gemini/constants.go b/relay/channel/gemini/constants.go
index 5bb0c168..4e7c57f9 100644
--- a/relay/channel/gemini/constants.go
+++ b/relay/channel/gemini/constants.go
@@ -1,6 +1,6 @@
 package gemini
 
 var ModelList = []string{
-	"gemini-pro",
-	"gemini-pro-vision",
+	"gemini-pro", "gemini-1.0-pro-001",
+	"gemini-pro-vision", "gemini-1.0-pro-vision-001",
 }
diff --git a/relay/channel/minimax/constants.go b/relay/channel/minimax/constants.go
new file mode 100644
index 00000000..c3da5b2d
--- /dev/null
+++ b/relay/channel/minimax/constants.go
@@ -0,0 +1,7 @@
+package minimax
+
+var ModelList = []string{
+	"abab5.5s-chat",
+	"abab5.5-chat",
+	"abab6-chat",
+}
diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go
new file mode 100644
index 00000000..a01821c2
--- /dev/null
+++ b/relay/channel/minimax/main.go
@@ -0,0 +1,14 @@
+package minimax
+
+import (
+	"fmt"
+	"github.com/songquanpeng/one-api/relay/constant"
+	"github.com/songquanpeng/one-api/relay/util"
+)
+
+func GetRequestURL(meta *util.RelayMeta) (string, error) {
+	if meta.Mode == constant.RelayModeChatCompletions {
+		return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
+	}
+	return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
+}
diff --git a/relay/channel/mistral/constants.go b/relay/channel/mistral/constants.go
new file mode 100644
index 00000000..cdb157f5
--- /dev/null
+++ b/relay/channel/mistral/constants.go
@@ -0,0 +1,10 @@
+package mistral
+
+var ModelList = []string{
+	"open-mistral-7b",
+	"open-mixtral-8x7b",
+	"mistral-small-latest",
+	"mistral-medium-latest",
+	"mistral-large-latest",
+	"mistral-embed",
+}
diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go
index 1313e317..5a04a768 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/channel/openai/adaptor.go
@@ -7,6 +7,9 @@ import (
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/relay/channel"
 	"github.com/songquanpeng/one-api/relay/channel/ai360"
+	"github.com/songquanpeng/one-api/relay/channel/baichuan"
+	"github.com/songquanpeng/one-api/relay/channel/minimax"
+	"github.com/songquanpeng/one-api/relay/channel/mistral"
 	"github.com/songquanpeng/one-api/relay/channel/moonshot"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
@@ -24,7 +27,8 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
 }
 
 func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-	if meta.ChannelType == common.ChannelTypeAzure {
+	switch meta.ChannelType {
+	case common.ChannelTypeAzure:
 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
 		requestURL := strings.Split(meta.RequestURLPath, "?")[0]
 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -38,8 +42,11 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 
 		requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
 		return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+	case common.ChannelTypeMinimax:
+		return minimax.GetRequestURL(meta)
+	default:
+		return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
 	}
-	return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
 }
 
 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
@@ -70,7 +77,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 	if meta.IsStream {
 		var responseText string
-		err, responseText = StreamHandler(c, resp, meta.Mode)
+		err, responseText, _ = StreamHandler(c, resp, meta.Mode)
 		usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
 	} else {
 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
@@ -84,6 +91,12 @@ func (a *Adaptor) GetModelList() []string {
 		return ai360.ModelList
 	case common.ChannelTypeMoonshot:
 		return moonshot.ModelList
+	case common.ChannelTypeBaichuan:
+		return baichuan.ModelList
+	case common.ChannelTypeMinimax:
+		return minimax.ModelList
+	case common.ChannelTypeMistral:
+		return mistral.ModelList
 	default:
 		return ModelList
 	}
@@ -97,6 +110,12 @@ func (a *Adaptor) GetChannelName() string {
 		return "360"
 	case common.ChannelTypeMoonshot:
 		return "moonshot"
+	case common.ChannelTypeBaichuan:
+		return "baichuan"
+	case common.ChannelTypeMinimax:
+		return "minimax"
+	case common.ChannelTypeMistral:
+		return "mistralai"
 	default:
 		return "openai"
 	}
diff --git a/relay/channel/openai/main.go b/relay/channel/openai/main.go
index fbe55cf9..d47cd164 100644
--- a/relay/channel/openai/main.go
+++ b/relay/channel/openai/main.go
@@ -14,7 +14,7 @@ import (
 	"strings"
 )
 
-func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string) {
+func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
 	responseText := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@@ -31,6 +31,7 @@
 	})
 	dataChan := make(chan string)
 	stopChan := make(chan bool)
+	var usage *model.Usage
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
@@ -54,6 +55,9 @@
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Delta.Content
 			}
+			if streamResponse.Usage != nil {
+				usage = streamResponse.Usage
+			}
 		case constant.RelayModeCompletions:
 			var streamResponse CompletionsStreamResponse
 			err := json.Unmarshal([]byte(data), &streamResponse)
@@ -86,9 +90,9 @@
 	})
 	err := resp.Body.Close()
 	if err != nil {
-		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}
-	return nil, responseText
+	return nil, responseText, usage
 }
 
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
diff --git a/relay/channel/openai/model.go b/relay/channel/openai/model.go
index b24485a8..6c0b2c53 100644
--- a/relay/channel/openai/model.go
+++ b/relay/channel/openai/model.go
@@ -132,6 +132,7 @@
 	Created int64                                 `json:"created"`
 	Model   string                                `json:"model"`
 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
+	Usage   *model.Usage                          `json:"usage"`
 }
 
 type CompletionsStreamResponse struct {
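`StreamHandler` now threads a third return value through: the last non-nil `usage` object seen on the stream. Standard OpenAI chat chunks omit the field, but some OpenAI-compatible upstreams attach token counts to the final chunk, which the new `Usage` pointer on `ChatCompletionsStreamResponse` picks up (it stays nil otherwise). A sketch of decoding such a terminal chunk, with the struct trimmed to the relevant fields:

```go
package main

import (
	"encoding/json"
	"fmt"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// Trimmed mirror of ChatCompletionsStreamResponse.
type streamChunk struct {
	Id    string `json:"id"`
	Usage *usage `json:"usage"` // nil on chunks that carry no usage
}

func main() {
	// The kind of final chunk some OpenAI-compatible vendors emit.
	data := `{"id":"chatcmpl-42","usage":{"prompt_tokens":9,"completion_tokens":12,"total_tokens":21}}`
	var chunk streamChunk
	if err := json.Unmarshal([]byte(data), &chunk); err != nil {
		panic(err)
	}
	if chunk.Usage != nil { // the same nil check StreamHandler performs
		fmt.Printf("%+v\n", *chunk.Usage) // {PromptTokens:9 CompletionTokens:12 TotalTokens:21}
	}
}
```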
diff --git a/relay/constant/image.go b/relay/constant/image.go
new file mode 100644
index 00000000..5e04895f
--- /dev/null
+++ b/relay/constant/image.go
@@ -0,0 +1,24 @@
+package constant
+
+var DalleSizeRatios = map[string]map[string]float64{
+	"dall-e-2": {
+		"256x256":   1,
+		"512x512":   1.125,
+		"1024x1024": 1.25,
+	},
+	"dall-e-3": {
+		"1024x1024": 1,
+		"1024x1792": 2,
+		"1792x1024": 2,
+	},
+}
+
+var DalleGenerationImageAmounts = map[string][2]int{
+	"dall-e-2": {1, 10},
+	"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
+}
+
+var DalleImagePromptLengthLimitations = map[string]int{
+	"dall-e-2": 1000,
+	"dall-e-3": 4000,
+}
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index e3745372..89fc69ce 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -36,6 +36,65 @@
 	return textRequest, nil
 }
 
+func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
+	imageRequest := &openai.ImageRequest{}
+	err := common.UnmarshalBodyReusable(c, imageRequest)
+	if err != nil {
+		return nil, err
+	}
+	if imageRequest.N == 0 {
+		imageRequest.N = 1
+	}
+	if imageRequest.Size == "" {
+		imageRequest.Size = "1024x1024"
+	}
+	if imageRequest.Model == "" {
+		imageRequest.Model = "dall-e-2"
+	}
+	return imageRequest, nil
+}
+
+func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
+	// model validation
+	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
+	if !hasValidSize {
+		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
+	}
+	// check prompt length
+	if imageRequest.Prompt == "" {
+		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
+	}
+	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
+		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
+	}
+	// Number of generated images validation
+	if !isWithinRange(imageRequest.Model, imageRequest.N) {
+		// channel not azure
+		if meta.ChannelType != common.ChannelTypeAzure {
+			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
+		}
+	}
+	return nil
+}
+
+func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
+	if imageRequest == nil {
+		return 0, errors.New("imageRequest is nil")
+	}
+	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
+	if !hasValidSize {
+		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size)
+	}
+	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
+		if imageRequest.Size == "1024x1024" {
+			imageCostRatio *= 2
+		} else {
+			imageCostRatio *= 1.5
+		}
+	}
+	return imageCostRatio, nil
+}
+
 func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
 	switch relayMode {
 	case constant.RelayModeChatCompletions:
@@ -113,13 +172,6 @@
 	if err != nil {
 		logger.Error(ctx, "error update user quota cache: "+err.Error())
 	}
-	// if quota != 0 {
-	// 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
-	// 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
-	// 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
-	// 	model.UpdateChannelUsedQuota(meta.ChannelId, quota)
-	// }
-
 	logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
 	model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
 	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
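With the helpers factored out, the image quota math in `RelayImageHelper` below reduces to `quota = int(modelRatio * groupRatio * imageCostRatio * 1000) * n`, where `getImageCostRatio` doubles the size ratio for HD dall-e-3 at 1024x1024 and applies 1.5x for the larger HD sizes. A worked example; the model and group ratios here are placeholders, not values from this diff:

```go
package main

import "fmt"

func main() {
	// dall-e-3, size 1024x1024, quality "hd", n = 2
	imageCostRatio := 1.0 // DalleSizeRatios["dall-e-3"]["1024x1024"]
	imageCostRatio *= 2   // hd at 1024x1024, per getImageCostRatio

	modelRatio := 1.0 // placeholder: the real value comes from GetModelRatio
	groupRatio := 1.0 // placeholder group/channel ratio
	n := 2

	quota := int(modelRatio*groupRatio*imageCostRatio*1000) * n
	fmt.Println(quota) // 4000
}
```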
diff --git a/relay/controller/image.go b/relay/controller/image.go
index 4e0ed172..339505b6 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -14,6 +14,7 @@ import (
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
 	"github.com/songquanpeng/one-api/relay/channel/openai"
+	"github.com/songquanpeng/one-api/relay/constant"
 	relaymodel "github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/util"
 
@@ -21,121 +22,66 @@ import (
 )
 
 func isWithinRange(element string, value int) bool {
-	if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
+	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
 		return false
 	}
-	min := common.DalleGenerationImageAmounts[element][0]
-	max := common.DalleGenerationImageAmounts[element][1]
+	min := constant.DalleGenerationImageAmounts[element][0]
+	max := constant.DalleGenerationImageAmounts[element][1]
 
 	return value >= min && value <= max
 }
 
 func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
-	imageModel := "dall-e-2"
-	imageSize := "1024x1024"
-
-	tokenId := c.GetInt("token_id")
-	channelType := c.GetInt("channel")
-	channelId := c.GetInt("channel_id")
-	userId := c.GetInt("id")
-	// group := c.GetString("group")
-
-	var imageRequest openai.ImageRequest
-	err := common.UnmarshalBodyReusable(c, &imageRequest)
+	ctx := c.Request.Context()
+	meta := util.GetRelayMeta(c)
+	imageRequest, err := getImageRequest(c, meta.Mode)
 	if err != nil {
-		return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
-	}
-
-	if imageRequest.N == 0 {
-		imageRequest.N = 1
-	}
-
-	// Size validation
-	if imageRequest.Size != "" {
-		imageSize = imageRequest.Size
-	}
-
-	// Model validation
-	if imageRequest.Model != "" {
-		imageModel = imageRequest.Model
-	}
-
-	imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
-
-	// Check if model is supported
-	if hasValidSize {
-		if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
-			if imageSize == "1024x1024" {
-				imageCostRatio *= 2
-			} else {
-				imageCostRatio *= 1.5
-			}
-		}
-	} else {
-		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
-	}
-
-	// Prompt validation
-	if imageRequest.Prompt == "" {
-		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
-	}
-
-	// Check prompt length
-	if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
-		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
-	}
-
-	// Number of generated images validation
-	if !isWithinRange(imageModel, imageRequest.N) {
-		// channel not azure
-		if channelType != common.ChannelTypeAzure {
-			return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
-		}
+		logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
+		return openai.ErrorWrapper(err, "invalid_image_request", http.StatusBadRequest)
 	}
 
 	// map model name
-	modelMapping := c.GetString("model_mapping")
-	isModelMapped := false
-	if modelMapping != "" {
-		modelMap := make(map[string]string)
-		err := json.Unmarshal([]byte(modelMapping), &modelMap)
-		if err != nil {
-			return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
-		}
-		if modelMap[imageModel] != "" {
-			imageModel = modelMap[imageModel]
-			isModelMapped = true
-		}
+	var isModelMapped bool
+	meta.OriginModelName = imageRequest.Model
+	imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
+	meta.ActualModelName = imageRequest.Model
+
+	// model validation
+	bizErr := validateImageRequest(imageRequest, meta)
+	if bizErr != nil {
+		return bizErr
 	}
-	baseURL := common.ChannelBaseURLs[channelType]
+
+	imageCostRatio, err := getImageCostRatio(imageRequest)
+	if err != nil {
+		return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
+	}
+
 	requestURL := c.Request.URL.String()
-	if c.GetString("base_url") != "" {
-		baseURL = c.GetString("base_url")
-	}
-	fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
-	if channelType == common.ChannelTypeAzure {
+	fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
+	if meta.ChannelType == common.ChannelTypeAzure {
 		// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
 		apiVersion := util.GetAzureAPIVersion(c)
 		// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
-		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
+		fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
 	}
 
 	var requestBody io.Reader
-	if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
+	if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
 		jsonStr, err := json.Marshal(imageRequest)
 		if err != nil {
-			return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
 	} else {
 		requestBody = c.Request.Body
 	}
 
-	modelRatio := common.GetModelRatio(imageModel)
-	// groupRatio := common.GetGroupRatio(group)
-	groupRatio := c.GetFloat64("channel_ratio")
+	modelRatio := common.GetModelRatio(imageRequest.Model)
+	// groupRatio := common.GetGroupRatio(meta.Group)
+	groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
 	ratio := modelRatio * groupRatio
-	userQuota, err := model.CacheGetUserQuota(userId)
+	userQuota, err := model.CacheGetUserQuota(meta.UserId)
 
 	quota := int(ratio*imageCostRatio*1000) * imageRequest.N
 
@@ -148,7 +94,7 @@
 		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 	}
 	token := c.Request.Header.Get("Authorization")
-	if channelType == common.ChannelTypeAzure { // Azure authentication
+	if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
 		token = strings.TrimPrefix(token, "Bearer ")
 		req.Header.Set("api-key", token)
 	} else {
@@ -171,25 +117,25 @@
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 	}
-	var textResponse openai.ImageResponse
+	var imageResponse openai.ImageResponse
 
 	defer func(ctx context.Context) {
 		if resp.StatusCode != http.StatusOK {
 			return
 		}
-		err := model.PostConsumeTokenQuota(tokenId, quota)
+		err := model.PostConsumeTokenQuota(meta.TokenId, quota)
 		if err != nil {
 			logger.SysError("error consuming token remain quota: " + err.Error())
 		}
-		err = model.CacheUpdateUserQuota(userId)
+		err = model.CacheUpdateUserQuota(meta.UserId)
 		if err != nil {
 			logger.SysError("error update user quota cache: " + err.Error())
 		}
 		if quota != 0 {
 			tokenName := c.GetString("token_name")
 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
-			model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
-			model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
+			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
+			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
 			channelId := c.GetInt("channel_id")
 			model.UpdateChannelUsedQuota(channelId, quota)
 		}
@@ -204,7 +150,7 @@
 	if err != nil {
 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
-	err = json.Unmarshal(responseBody, &textResponse)
+	err = json.Unmarshal(responseBody, &imageResponse)
 	if err != nil {
 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
diff --git a/relay/controller/text.go b/relay/controller/text.go
index 46c62a3c..9dd0e5a5 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -56,7 +56,8 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
 	var requestBody io.Reader
 	if meta.APIType == constant.APITypeOpenAI {
 		// no need to convert request for openai
-		if isModelMapped {
+		shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
+		if shouldResetRequestBody {
 			jsonStr, err := json.Marshal(textRequest)
 			if err != nil {
 				return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
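The Baichuan special case works because re-marshaling the parsed request normalizes it: assuming optional parameters on the request struct carry `omitempty` tags (an assumption; the tags are not shown in this diff), an explicit `"frequency_penalty": 0` from the client disappears on the round trip, which is exactly the value the upstream rejects:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative stand-in for the relay's request struct.
type request struct {
	Model            string  `json:"model"`
	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // assumed tag
}

func main() {
	in := []byte(`{"model":"Baichuan2-Turbo","frequency_penalty":0}`)
	var r request
	if err := json.Unmarshal(in, &r); err != nil {
		panic(err)
	}
	out, _ := json.Marshal(r)
	fmt.Println(string(out)) // {"model":"Baichuan2-Turbo"}, the zero is dropped
}
```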
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index aeff5190..31c45048 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -29,6 +29,12 @@ export const CHANNEL_OPTIONS = {
     value: 24,
     color: 'orange'
   },
+  28: {
+    key: 28,
+    text: 'Mistral AI',
+    value: 28,
+    color: 'orange'
+  },
   15: {
     key: 15,
     text: '百度文心千帆',
@@ -71,6 +77,18 @@ export const CHANNEL_OPTIONS = {
     value: 23,
     color: 'default'
   },
+  26: {
+    key: 26,
+    text: '百川大模型',
+    value: 26,
+    color: 'default'
+  },
+  27: {
+    key: 27,
+    text: 'MiniMax',
+    value: 27,
+    color: 'default'
+  },
   8: {
     key: 8,
     text: '自定义渠道',
diff --git a/web/berry/src/views/Channel/type/Config.js b/web/berry/src/views/Channel/type/Config.js
index a091c8d6..4dec33de 100644
--- a/web/berry/src/views/Channel/type/Config.js
+++ b/web/berry/src/views/Channel/type/Config.js
@@ -67,7 +67,7 @@ const typeConfig = {
   },
   16: {
     input: {
-      models: ["chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"],
+      models: ["glm-4", "glm-4v", "glm-3-turbo", "chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite"],
     },
     modelGroup: "zhipu",
   },
@@ -145,6 +145,24 @@ const typeConfig = {
     },
     modelGroup: "google gemini",
   },
+  25: {
+    input: {
+      models: ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'],
+    },
+    modelGroup: "moonshot",
+  },
+  26: {
+    input: {
+      models: ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'],
+    },
+    modelGroup: "baichuan",
+  },
+  27: {
+    input: {
+      models: ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'],
+    },
+    modelGroup: "minimax",
+  },
 };
 
 export { defaultConfig, typeConfig };
diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js
index 16da1b97..b21bb15d 100644
--- a/web/default/src/constants/channel.constants.js
+++ b/web/default/src/constants/channel.constants.js
@@ -4,6 +4,7 @@ export const CHANNEL_OPTIONS = [
   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
   { key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
+  { key: 28, text: 'Mistral AI', value: 28, color: 'orange' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
@@ -11,6 +12,8 @@ export const CHANNEL_OPTIONS = [
   { key: 19, text: '360 智脑', value: 19, color: 'blue' },
   { key: 25, text: 'Moonshot AI', value: 25, color: 'black' },
   { key: 23, text: '腾讯混元', value: 23, color: 'teal' },
+  { key: 26, text: '百川大模型', value: 26, color: 'orange' },
+  { key: 27, text: 'MiniMax', value: 27, color: 'red' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js
index 9cee664d..5811604e 100644
--- a/web/default/src/pages/Channel/EditChannel.js
+++ b/web/default/src/pages/Channel/EditChannel.js
@@ -79,7 +79,7 @@ const EditChannel = () => {
         localModels = [...localModels, ...withInternetVersion];
         break;
       case 16:
-        localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
+        localModels = ['glm-4', 'glm-4v', 'glm-3-turbo', 'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
         break;
       case 18:
         localModels = [
@@ -102,6 +102,12 @@ const EditChannel = () => {
       case 25:
         localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
         break;
+      case 26:
+        localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'];
+        break;
+      case 27:
+        localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'];
+        break;
     }
     setInputs((inputs) => ({ ...inputs, models: localModels }));
   }