Compare commits

..

1 Commits

Author SHA1 Message Date
Laisky.Cai
d4703dfd97 feat: add Proxy channel type and relay mode
Add the Proxy channel type and relay mode to support proxying requests to custom upstream services.
2024-07-21 13:42:13 +00:00
156 changed files with 2275 additions and 6399 deletions

View File

@@ -1,17 +1,19 @@
name: CI name: CI
# This setup assumes that you run the unit tests with code coverage in the same # This setup assumes that you run the unit tests with code coverage in the same
# workflow that will also print the coverage report as comment to the pull request. # workflow that will also print the coverage report as comment to the pull request.
# Therefore, you need to trigger this workflow when a pull request is (re)opened or # Therefore, you need to trigger this workflow when a pull request is (re)opened or
# when new code is pushed to the branch of the pull request. In addition, you also # when new code is pushed to the branch of the pull request. In addition, you also
# need to trigger this workflow when new code is pushed to the main branch because # need to trigger this workflow when new code is pushed to the main branch because
# we need to upload the code coverage results as artifact for the main branch as # we need to upload the code coverage results as artifact for the main branch as
# well since it will be the baseline code coverage. # well since it will be the baseline code coverage.
# #
# We do not want to trigger the workflow for pushes to *any* branch because this # We do not want to trigger the workflow for pushes to *any* branch because this
# would trigger our jobs twice on pull requests (once from "push" event and once # would trigger our jobs twice on pull requests (once from "push" event and once
# from "pull_request->synchronize") # from "pull_request->synchronize")
on: on:
pull_request:
types: [opened, reopened, synchronize]
push: push:
branches: branches:
- 'main' - 'main'
@@ -29,7 +31,7 @@ jobs:
with: with:
go-version: ^1.22 go-version: ^1.22
# When you execute your unit tests, make sure to use the "-coverprofile" flag to write a # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a
# coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt")
# in the next step as well as the next job. # in the next step as well as the next job.
- name: Test - name: Test

View File

@@ -1,4 +1,4 @@
name: Publish Docker image (English) name: Publish Docker image (amd64, English)
on: on:
push: push:
@@ -34,13 +34,6 @@ jobs:
- name: Translate - name: Translate
run: | run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Log in to Docker Hub - name: Log in to Docker Hub
uses: docker/login-action@v2 uses: docker/login-action@v2
with: with:
@@ -58,7 +51,6 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }} labels: ${{ steps.meta.outputs.labels }}

View File

@@ -0,0 +1,61 @@
name: Publish Docker image (amd64)
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs:
push_to_registries:
name: Push Docker image to multiple registries
runs-on: ubuntu-latest
permissions:
packages: write
contents: read
steps:
- name: Check out the repo
uses: actions/checkout@v3
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 1
fi
- name: Save version info
run: |
git describe --tags > VERSION
- name: Log in to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Log in to the Container registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v4
with:
images: |
justsong/one-api
ghcr.io/${{ github.repository }}
- name: Build and push Docker images
uses: docker/build-push-action@v3
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -1,9 +1,10 @@
name: Publish Docker image name: Publish Docker image (arm64)
on: on:
push: push:
tags: tags:
- 'v*.*.*' - 'v*.*.*'
- '!*-alpha*'
workflow_dispatch: workflow_dispatch:
inputs: inputs:
name: name:

3
.gitignore vendored
View File

@@ -9,5 +9,4 @@ logs
data data
/web/node_modules /web/node_modules
cmd.md cmd.md
.env .env
/one-api

View File

@@ -1,4 +1,4 @@
FROM --platform=$BUILDPLATFORM node:16 AS builder FROM node:16 as builder
WORKDIR /web WORKDIR /web
COPY ./VERSION . COPY ./VERSION .

View File

@@ -89,8 +89,6 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [DeepL](https://www.deepl.com/) + [x] [DeepL](https://www.deepl.com/)
+ [x] [together.ai](https://www.together.ai/) + [x] [together.ai](https://www.together.ai/)
+ [x] [novita.ai](https://www.novita.ai/) + [x] [novita.ai](https://www.novita.ai/)
+ [x] [硅基流动 SiliconCloud](https://siliconflow.cn/siliconcloud)
+ [x] [xAI](https://x.ai/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
@@ -115,8 +113,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
21. 支持 Cloudflare Turnstile 用户校验。 21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式** 22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持[飞书授权登录](https://open.feishu.cn/document/uAjLw4CM/ukTMukTMukTM/reference/authen-v1/authorize/get)[这里有 One API 的实现细节阐述供参考](https://iamazing.cn/page/feishu-oauth-login) + 支持使用飞书进行授权登录
+ 支持 [GitHub 授权登录](https://github.com/settings/applications/new)。 + [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。 23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。 24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
@@ -175,10 +173,6 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456` 初始账号用户名为 `root`,密码为 `123456`
### 通过宝塔面板进行一键部署
1. 安装宝塔面板9.2.0及以上版本,前往 [宝塔面板](https://www.bt.cn/new/download.html?r=dk_oneapi) 官网,选择正式版的脚本下载安装;
2. 安装后登录宝塔面板,在左侧菜单栏中点击 `Docker`,首次进入会提示安装 `Docker` 服务,点击立即安装,按提示完成安装;
3. 安装完成后在应用商店中搜索 `One-API`,点击安装,配置域名等基本信息即可完成安装;
### 基于 Docker Compose 进行部署 ### 基于 Docker Compose 进行部署
@@ -222,7 +216,7 @@ docker-compose ps
3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。 3. 所有从服务器必须设置 `NODE_TYPE` 为 `slave`,不设置则默认为主服务器。
4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis无论主从。 4. 设置 `SYNC_FREQUENCY` 后服务器将定期从数据库同步配置,在使用远程数据库的情况下,推荐设置该项并启用 Redis无论主从。
5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。 5. 从服务器可以选择设置 `FRONTEND_BASE_URL`,以重定向页面请求到主服务器。
6. 从服务器上**分别**装好 Redis设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟Redis 集群或者哨兵模式的支持请参考环境变量说明) 6. 从服务器上**分别**装好 Redis设置好 `REDIS_CONN_STRING`,这样可以做到在缓存未过期的情况下数据库零访问,可以减少延迟。
7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。 7. 如果主服务器访问数据库延迟也比较高,则也需要启用 Redis并设置 `SYNC_FREQUENCY`,以定期从数据库同步配置。
环境变量的具体使用方法详见[此处](#环境变量)。 环境变量的具体使用方法详见[此处](#环境变量)。
@@ -257,9 +251,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
#### QChatGPT - QQ机器人 #### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT 项目主页https://github.com/RockChinQ/QChatGPT
根据[文档](https://qchatgpt.rockchin.top)完成部署后,在 `data/provider.json`设置`requester.openai-chat-completions.base-url`为 One API 实例地址,并填写 API Key 到 `keys.openai` 组中,设置 `model` 为要使用的模型名称。 根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
运行期间可以通过`!model`命令查看、切换可用模型。 可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
### 部署到第三方平台 ### 部署到第三方平台
<details> <details>
@@ -351,11 +345,6 @@ graph LR
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+ 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。 + 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。
+ 如果需要使用哨兵或者集群模式:
+ 则需要把该环境变量设置为节点列表,例如:`localhost:49153,localhost:49154,localhost:49155`。
+ 除此之外还需要设置以下环境变量:
+ `REDIS_PASSWORD`Redis 集群或者哨兵模式下的密码设置。
+ `REDIS_MASTER_NAME`Redis 哨兵模式下主节点的名称。
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string` + 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL。 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL。
@@ -409,8 +398,6 @@ graph LR
26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。 26. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。
29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage默认不开启可选值为 `true` 和 `false`。
30. `TEST_PROMPT`:测试模型时的用户 prompt默认为 `Print your model name exactly and do not output without any other text.`。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。

View File

@@ -1,14 +1,13 @@
package config package config
import ( import (
"github.com/songquanpeng/one-api/common/env"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/songquanpeng/one-api/common/env"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -36,7 +35,6 @@ var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false var GitHubOAuthEnabled = false
var OidcEnabled = false
var WeChatAuthEnabled = false var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false var TurnstileCheckEnabled = false
var RegisterEnabled = true var RegisterEnabled = true
@@ -72,13 +70,6 @@ var GitHubClientSecret = ""
var LarkClientId = "" var LarkClientId = ""
var LarkClientSecret = "" var LarkClientSecret = ""
var OidcClientId = ""
var OidcClientSecret = ""
var OidcWellKnown = ""
var OidcAuthorizationEndpoint = ""
var OidcTokenEndpoint = ""
var OidcUserinfoEndpoint = ""
var WeChatServerAddress = "" var WeChatServerAddress = ""
var WeChatServerToken = "" var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = "" var WeChatAccountQRCodeImageURL = ""
@@ -161,6 +152,3 @@ var OnlyOneLogFile = env.Bool("ONLY_ONE_LOG_FILE", false)
var RelayProxy = env.String("RELAY_PROXY", "") var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.")

View File

@@ -20,5 +20,4 @@ const (
BaseURL = "base_url" BaseURL = "base_url"
AvailableModels = "available_models" AvailableModels = "available_models"
KeyRequestBody = "key_request_body" KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
) )

View File

@@ -31,15 +31,15 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type") contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") { if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v) err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else { } else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) // skip for now
err = c.ShouldBind(&v) // TODO: someday non json request have variant model, we will need to implementation this
} }
if err != nil { if err != nil {
return err return err
} }
// Reset request body // Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil return nil
} }

View File

@@ -1,8 +1,9 @@
package helper package helper
import ( import (
"context"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
"html/template" "html/template"
"log" "log"
"net" "net"
@@ -10,10 +11,6 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/random"
) )
func OpenBrowser(url string) { func OpenBrowser(url string) {
@@ -109,18 +106,6 @@ func GenRequestID() string {
return GetTimeString() + random.GetRandomNumberString(8) return GetTimeString() + random.GetRandomNumberString(8)
} }
func SetRequestID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, RequestIdKey, id)
}
func GetRequestID(ctx context.Context) string {
rawRequestId := ctx.Value(RequestIdKey)
if rawRequestId == nil {
return ""
}
return rawRequestId.(string)
}
func GetResponseID(c *gin.Context) string { func GetResponseID(c *gin.Context) string {
logID := c.GetString(RequestIdKey) logID := c.GetString(RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)
@@ -152,23 +137,3 @@ func String2Int(str string) int {
} }
return num return num
} }
func Float64PtrMax(p *float64, maxValue float64) *float64 {
if p == nil {
return nil
}
if *p > maxValue {
return &maxValue
}
return p
}
func Float64PtrMin(p *float64, minValue float64) *float64 {
if p == nil {
return nil
}
if *p < minValue {
return &minValue
}
return p
}

View File

@@ -13,8 +13,3 @@ func GetTimeString() string {
now := time.Now() now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
} }
// CalcElapsedTime return the elapsed time in milliseconds (ms)
func CalcElapsedTime(start time.Time) int64 {
return time.Now().Sub(start).Milliseconds()
}

View File

@@ -7,25 +7,19 @@ import (
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
) )
type loggerLevel string
const ( const (
loggerDEBUG loggerLevel = "DEBUG" loggerDEBUG = "DEBUG"
loggerINFO loggerLevel = "INFO" loggerINFO = "INFO"
loggerWarn loggerLevel = "WARN" loggerWarn = "WARN"
loggerError loggerLevel = "ERROR" loggerError = "ERR"
loggerFatal loggerLevel = "FATAL"
) )
var setupLogOnce sync.Once var setupLogOnce sync.Once
@@ -50,26 +44,27 @@ func SetupLogger() {
} }
func SysLog(s string) { func SysLog(s string) {
logHelper(nil, loggerINFO, s) t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysLogf(format string, a ...any) { func SysLogf(format string, a ...any) {
logHelper(nil, loggerINFO, fmt.Sprintf(format, a...)) SysLog(fmt.Sprintf(format, a...))
} }
func SysError(s string) { func SysError(s string) {
logHelper(nil, loggerError, s) t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
} }
func SysErrorf(format string, a ...any) { func SysErrorf(format string, a ...any) {
logHelper(nil, loggerError, fmt.Sprintf(format, a...)) SysError(fmt.Sprintf(format, a...))
} }
func Debug(ctx context.Context, msg string) { func Debug(ctx context.Context, msg string) {
if !config.DebugEnabled { if config.DebugEnabled {
return logHelper(ctx, loggerDEBUG, msg)
} }
logHelper(ctx, loggerDEBUG, msg)
} }
func Info(ctx context.Context, msg string) { func Info(ctx context.Context, msg string) {
@@ -85,65 +80,37 @@ func Error(ctx context.Context, msg string) {
} }
func Debugf(ctx context.Context, format string, a ...any) { func Debugf(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerDEBUG, fmt.Sprintf(format, a...)) Debug(ctx, fmt.Sprintf(format, a...))
} }
func Infof(ctx context.Context, format string, a ...any) { func Infof(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerINFO, fmt.Sprintf(format, a...)) Info(ctx, fmt.Sprintf(format, a...))
} }
func Warnf(ctx context.Context, format string, a ...any) { func Warnf(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerWarn, fmt.Sprintf(format, a...)) Warn(ctx, fmt.Sprintf(format, a...))
} }
func Errorf(ctx context.Context, format string, a ...any) { func Errorf(ctx context.Context, format string, a ...any) {
logHelper(ctx, loggerError, fmt.Sprintf(format, a...)) Error(ctx, fmt.Sprintf(format, a...))
} }
func FatalLog(s string) { func logHelper(ctx context.Context, level string, msg string) {
logHelper(nil, loggerFatal, s)
}
func FatalLogf(format string, a ...any) {
logHelper(nil, loggerFatal, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level loggerLevel, msg string) {
writer := gin.DefaultErrorWriter writer := gin.DefaultErrorWriter
if level == loggerINFO { if level == loggerINFO {
writer = gin.DefaultWriter writer = gin.DefaultWriter
} }
var requestId string id := ctx.Value(helper.RequestIdKey)
if ctx != nil { if id == nil {
rawRequestId := helper.GetRequestID(ctx) id = helper.GenRequestID()
if rawRequestId != "" {
requestId = fmt.Sprintf(" | %s", rawRequestId)
}
} }
lineInfo, funcName := getLineInfo()
now := time.Now() now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v%s%s %s%s \n", level, now.Format("2006/01/02 - 15:04:05"), requestId, lineInfo, funcName, msg) _, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
SetupLogger() SetupLogger()
if level == loggerFatal {
os.Exit(1)
}
} }
func getLineInfo() (string, string) { func FatalLog(v ...any) {
funcName := "[unknown] " t := time.Now()
pc, file, line, ok := runtime.Caller(3) _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
if ok { os.Exit(1)
if fn := runtime.FuncForPC(pc); fn != nil {
parts := strings.Split(fn.Name(), ".")
funcName = "[" + parts[len(parts)-1] + "] "
}
} else {
file = "unknown"
line = 0
}
parts := strings.Split(file, "one-api/")
if len(parts) > 1 {
file = parts[1]
}
return fmt.Sprintf(" | %s:%d", file, line), funcName
} }

View File

@@ -2,15 +2,13 @@ package common
import ( import (
"context" "context"
"os"
"strings"
"time"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"os"
"time"
) )
var RDB redis.Cmdable var RDB *redis.Client
var RedisEnabled = true var RedisEnabled = true
// InitRedisClient This function is called after init() // InitRedisClient This function is called after init()
@@ -25,23 +23,13 @@ func InitRedisClient() (err error) {
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil return nil
} }
redisConnString := os.Getenv("REDIS_CONN_STRING") logger.SysLog("Redis is enabled")
if os.Getenv("REDIS_MASTER_NAME") == "" { opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
logger.SysLog("Redis is enabled") if err != nil {
opt, err := redis.ParseURL(redisConnString) logger.FatalLog("failed to parse Redis connection string: " + err.Error())
if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error())
}
RDB = redis.NewClient(opt)
} else {
// cluster mode
logger.SysLog("Redis cluster mode enabled")
RDB = redis.NewUniversalClient(&redis.UniversalOptions{
Addrs: strings.Split(redisConnString, ","),
Password: os.Getenv("REDIS_PASSWORD"),
MasterName: os.Getenv("REDIS_MASTER_NAME"),
})
} }
RDB = redis.NewClient(opt)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()

View File

@@ -3,10 +3,9 @@ package render
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"strings"
) )
func StringData(c *gin.Context, str string) { func StringData(c *gin.Context, str string) {

View File

@@ -5,18 +5,16 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type GitHubOAuthResponse struct { type GitHubOAuthResponse struct {
@@ -83,7 +81,6 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
} }
func GitHubOAuth(c *gin.Context) { func GitHubOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -139,7 +136,7 @@ func GitHubOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(ctx, 0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -5,17 +5,15 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type LarkOAuthResponse struct { type LarkOAuthResponse struct {
@@ -42,7 +40,7 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
req, err := http.NewRequest("POST", "https://open.feishu.cn/open-apis/authen/v2/oauth/token", bytes.NewBuffer(jsonData)) req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -81,7 +79,6 @@ func getLarkUserInfoByCode(code string) (*LarkUser, error) {
} }
func LarkOAuth(c *gin.Context) { func LarkOAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state") state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
@@ -128,7 +125,7 @@ func LarkOAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(ctx, 0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -1,228 +0,0 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
)
type OidcResponse struct {
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
Scope string `json:"scope"`
}
type OidcUser struct {
OpenID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
}
func getOidcUserInfoByCode(code string) (*OidcUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.OidcClientId,
"client_secret": config.OidcClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
defer res.Body.Close()
var oidcResponse OidcResponse
err = json.NewDecoder(res.Body).Decode(&oidcResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken)
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!")
}
var oidcUser OidcUser
err = json.NewDecoder(res2.Body).Decode(&oidcUser)
if err != nil {
return nil, err
}
return &oidcUser, nil
}
func OidcAuth(c *gin.Context) {
ctx := c.Request.Context()
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
OidcBind(c)
return
}
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
err := user.FillUserByOidcId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Email = oidcUser.Email
if oidcUser.PreferredUsername != "" {
user.Username = oidcUser.PreferredUsername
} else {
user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1)
}
if oidcUser.Name != "" {
user.DisplayName = oidcUser.Name
} else {
user.DisplayName = "OIDC User"
}
err := user.Insert(ctx, 0)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func OidcBind(c *gin.Context) {
if !config.OidcEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 OIDC 登录以及注册",
})
return
}
code := c.Query("code")
oidcUser, err := getOidcUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
OidcId: oidcUser.OpenID,
}
if model.IsOidcIdAlreadyTaken(user.OidcId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 OIDC 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.OidcId = oidcUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -4,16 +4,14 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
) )
type wechatLoginResponse struct { type wechatLoginResponse struct {
@@ -54,7 +52,6 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
ctx := c.Request.Context()
if !config.WeChatAuthEnabled { if !config.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
@@ -90,7 +87,7 @@ func WeChatAuth(c *gin.Context) {
user.Role = model.RoleCommonUser user.Role = model.RoleCommonUser
user.Status = model.UserStatusEnabled user.Status = model.UserStatusEnabled
if err := user.Insert(ctx, 0); err != nil { if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -17,11 +17,9 @@ func GetSubscription(c *gin.Context) {
if config.DisplayTokenStatEnabled { if config.DisplayTokenStatEnabled {
tokenId := c.GetInt(ctxkey.TokenId) tokenId := c.GetInt(ctxkey.TokenId)
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
if err == nil { expiredTime = token.ExpiredTime
expiredTime = token.ExpiredTime remainQuota = token.RemainQuota
remainQuota = token.RemainQuota usedQuota = token.UsedQuota
usedQuota = token.UsedQuota
}
} else { } else {
userId := c.GetInt(ctxkey.Id) userId := c.GetInt(ctxkey.Id)
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId)

View File

@@ -4,17 +4,16 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net/http"
"strconv"
"time"
"github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"io"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -82,36 +81,6 @@ type APGC2DGPTUsageResponse struct {
TotalUsed float64 `json:"total_used"` TotalUsed float64 `json:"total_used"`
} }
type SiliconFlowUsageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Status bool `json:"status"`
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Image string `json:"image"`
Email string `json:"email"`
IsAdmin bool `json:"isAdmin"`
Balance string `json:"balance"`
Status string `json:"status"`
Introduction string `json:"introduction"`
Role string `json:"role"`
ChargeBalance string `json:"chargeBalance"`
TotalBalance string `json:"totalBalance"`
Category string `json:"category"`
} `json:"data"`
}
type DeepSeekUsageResponse struct {
IsAvailable bool `json:"is_available"`
BalanceInfos []struct {
Currency string `json:"currency"`
TotalBalance string `json:"total_balance"`
GrantedBalance string `json:"granted_balance"`
ToppedUpBalance string `json:"topped_up_balance"`
} `json:"balance_infos"`
}
// GetAuthHeader get auth header // GetAuthHeader get auth header
func GetAuthHeader(token string) http.Header { func GetAuthHeader(token string) http.Header {
h := http.Header{} h := http.Header{}
@@ -234,57 +203,6 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
return response.TotalAvailable, nil return response.TotalAvailable, nil
} }
func updateChannelSiliconFlowBalance(channel *model.Channel) (float64, error) {
url := "https://api.siliconflow.cn/v1/user/info"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := SiliconFlowUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
if response.Code != 20000 {
return 0, fmt.Errorf("code: %d, message: %s", response.Code, response.Message)
}
balance, err := strconv.ParseFloat(response.Data.TotalBalance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelDeepSeekBalance(channel *model.Channel) (float64, error) {
url := "https://api.deepseek.com/user/balance"
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil {
return 0, err
}
response := DeepSeekUsageResponse{}
err = json.Unmarshal(body, &response)
if err != nil {
return 0, err
}
index := -1
for i, balanceInfo := range response.BalanceInfos {
if balanceInfo.Currency == "CNY" {
index = i
break
}
}
if index == -1 {
return 0, errors.New("currency CNY not found")
}
balance, err := strconv.ParseFloat(response.BalanceInfos[index].TotalBalance, 64)
if err != nil {
return 0, err
}
channel.UpdateBalance(balance)
return balance, nil
}
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := channeltype.ChannelBaseURLs[channel.Type] baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.GetBaseURL() == "" {
@@ -309,10 +227,6 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return updateChannelAPI2GPTBalance(channel) return updateChannelAPI2GPTBalance(channel)
case channeltype.AIGC2D: case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel) return updateChannelAIGC2DBalance(channel)
case channeltype.SiliconFlow:
return updateChannelSiliconFlowBalance(channel)
case channeltype.DeepSeek:
return updateChannelDeepSeekBalance(channel)
default: default:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
} }

View File

@@ -2,7 +2,6 @@ package controller
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -16,17 +15,14 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
@@ -39,34 +35,18 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
model = "gpt-3.5-turbo" model = "gpt-3.5-turbo"
} }
testRequest := &relaymodel.GeneralOpenAIRequest{ testRequest := &relaymodel.GeneralOpenAIRequest{
Model: model, MaxTokens: 2,
Model: model,
} }
testMessage := relaymodel.Message{ testMessage := relaymodel.Message{
Role: "user", Role: "user",
Content: config.TestPrompt, Content: "hi",
} }
testRequest.Messages = append(testRequest.Messages, testMessage) testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest return testRequest
} }
func parseTestResponse(resp string) (*openai.TextResponse, string, error) { func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
var response openai.TextResponse
err := json.Unmarshal([]byte(resp), &response)
if err != nil {
return nil, "", err
}
if len(response.Choices) == 0 {
return nil, "", errors.New("response has no choices")
}
stringContent, ok := response.Choices[0].Content.(string)
if !ok {
return nil, "", errors.New("response content is not string")
}
return &response, stringContent, nil
}
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
startTime := time.Now()
w := httptest.NewRecorder() w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w) c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{ c.Request = &http.Request{
@@ -86,7 +66,7 @@ func testChannel(ctx context.Context, channel *model.Channel, request *relaymode
apiType := channeltype.ToAPIType(channel.Type) apiType := channeltype.ToAPIType(channel.Type)
adaptor := relay.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
adaptor.Init(meta) adaptor.Init(meta)
modelName := request.Model modelName := request.Model
@@ -96,77 +76,49 @@ func testChannel(ctx context.Context, channel *model.Channel, request *relaymode
if len(modelNames) > 0 { if len(modelNames) > 0 {
modelName = modelNames[0] modelName = modelNames[0]
} }
} if modelMap != nil && modelMap[modelName] != "" {
if modelMap != nil && modelMap[modelName] != "" { modelName = modelMap[modelName]
modelName = modelMap[modelName] }
} }
meta.OriginModelName, meta.ActualModelName = request.Model, modelName meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName request.Model = modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return "", err, nil return err, nil
} }
jsonData, err := json.Marshal(convertedRequest) jsonData, err := json.Marshal(convertedRequest)
if err != nil { if err != nil {
return "", err, nil return err, nil
} }
defer func() {
logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
if err != nil || openaiErr != nil {
errorMessage := ""
if err != nil {
errorMessage = err.Error()
} else {
errorMessage = openaiErr.Message
}
logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
}
go model.RecordTestLog(ctx, &model.Log{
ChannelId: channel.Id,
ModelName: modelName,
Content: logContent,
ElapsedTime: helper.CalcElapsedTime(startTime),
})
}()
logger.SysLog(string(jsonData)) logger.SysLog(string(jsonData))
requestBody := bytes.NewBuffer(jsonData) requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody) c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody) resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil { if err != nil {
return "", err, nil return err, nil
} }
if resp != nil && resp.StatusCode != http.StatusOK { if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp) err := controller.RelayErrorHandler(resp)
errorMessage := err.Error.Message return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
if errorMessage != "" {
errorMessage = ", error message: " + errorMessage
}
return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
} }
if usage == nil { if usage == nil {
return "", errors.New("usage is nil"), nil return errors.New("usage is nil"), nil
}
rawResponse := w.Body.String()
_, responseMessage, err = parseTestResponse(rawResponse)
if err != nil {
return "", err, nil
} }
result := w.Result() result := w.Result()
// print result.Body // print result.Body
respBody, err := io.ReadAll(result.Body) respBody, err := io.ReadAll(result.Body)
if err != nil { if err != nil {
return "", err, nil return err, nil
} }
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return responseMessage, nil, nil return nil, nil
} }
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
ctx := c.Request.Context()
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -183,10 +135,10 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
modelName := c.Query("model") model := c.Query("model")
testRequest := buildTestRequest(modelName) testRequest := buildTestRequest(model)
tik := time.Now() tik := time.Now()
responseMessage, err, _ := testChannel(ctx, channel, testRequest) err, _ = testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if err != nil { if err != nil {
@@ -196,18 +148,18 @@ func TestChannel(c *gin.Context) {
consumedTime := float64(milliseconds) / 1000.0 consumedTime := float64(milliseconds) / 1000.0
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
"time": consumedTime, "time": consumedTime,
"modelName": modelName, "model": model,
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": responseMessage, "message": "",
"time": consumedTime, "time": consumedTime,
"modelName": modelName, "model": model,
}) })
return return
} }
@@ -215,7 +167,7 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func testChannels(ctx context.Context, notify bool, scope string) error { func testChannels(notify bool, scope string) error {
if config.RootUserEmail == "" { if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail() config.RootUserEmail = model.GetRootUserEmail()
} }
@@ -239,7 +191,7 @@ func testChannels(ctx context.Context, notify bool, scope string) error {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
testRequest := buildTestRequest("") testRequest := buildTestRequest("")
_, err, openaiErr := testChannel(ctx, channel, testRequest) err, openaiErr := testChannel(channel, testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if isChannelEnabled && milliseconds > disableThreshold {
@@ -273,12 +225,11 @@ func testChannels(ctx context.Context, notify bool, scope string) error {
} }
func TestChannels(c *gin.Context) { func TestChannels(c *gin.Context) {
ctx := c.Request.Context()
scope := c.Query("scope") scope := c.Query("scope")
if scope == "" { if scope == "" {
scope = "all" scope = "all"
} }
err := testChannels(ctx, true, scope) err := testChannels(true, scope)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -294,11 +245,10 @@ func TestChannels(c *gin.Context) {
} }
func AutomaticallyTestChannels(frequency int) { func AutomaticallyTestChannels(frequency int) {
ctx := context.Background()
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("testing all channels") logger.SysLog("testing all channels")
_ = testChannels(ctx, false, "all") _ = testChannels(false, "all")
logger.SysLog("channel test finished") logger.SysLog("channel test finished")
} }
} }

View File

@@ -18,30 +18,24 @@ func GetStatus(c *gin.Context) {
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
"version": common.Version, "version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled, "email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled, "github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId, "github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId, "lark_client_id": config.LarkClientId,
"system_name": config.SystemName, "system_name": config.SystemName,
"logo": config.Logo, "logo": config.Logo,
"footer_html": config.Footer, "footer_html": config.Footer,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL, "wechat_qrcode": config.WeChatAccountQRCodeImageURL,
"wechat_login": config.WeChatAuthEnabled, "wechat_login": config.WeChatAuthEnabled,
"server_address": config.ServerAddress, "server_address": config.ServerAddress,
"turnstile_check": config.TurnstileCheckEnabled, "turnstile_check": config.TurnstileCheckEnabled,
"turnstile_site_key": config.TurnstileSiteKey, "turnstile_site_key": config.TurnstileSiteKey,
"top_up_link": config.TopUpLink, "top_up_link": config.TopUpLink,
"chat_link": config.ChatLink, "chat_link": config.ChatLink,
"quota_per_unit": config.QuotaPerUnit, "quota_per_unit": config.QuotaPerUnit,
"display_in_currency": config.DisplayInCurrencyEnabled, "display_in_currency": config.DisplayInCurrencyEnabled,
"oidc": config.OidcEnabled,
"oidc_client_id": config.OidcClientId,
"oidc_well_known": config.OidcWellKnown,
"oidc_authorization_endpoint": config.OidcAuthorizationEndpoint,
"oidc_token_endpoint": config.OidcTokenEndpoint,
"oidc_userinfo_endpoint": config.OidcUserinfoEndpoint,
}, },
}) })
return return

View File

@@ -60,7 +60,7 @@ func Relay(c *gin.Context) {
channelName := c.GetString(ctxkey.ChannelName) channelName := c.GetString(ctxkey.ChannelName)
group := c.GetString(ctxkey.Group) group := c.GetString(ctxkey.Group)
originalModel := c.GetString(ctxkey.OriginalModel) originalModel := c.GetString(ctxkey.OriginalModel)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
requestId := c.GetString(helper.RequestIdKey) requestId := c.GetString(helper.RequestIdKey)
retryTimes := config.RetryTimes retryTimes := config.RetryTimes
if !shouldRetry(c, bizErr.StatusCode) { if !shouldRetry(c, bizErr.StatusCode) {
@@ -87,7 +87,8 @@ func Relay(c *gin.Context) {
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
lastFailedChannelId = channelId lastFailedChannelId = channelId
channelName := c.GetString(ctxkey.ChannelName) channelName := c.GetString(ctxkey.ChannelName)
go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) // BUG: bizErr is in race condition
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
} }
if bizErr != nil { if bizErr != nil {
if bizErr.StatusCode == http.StatusTooManyRequests { if bizErr.StatusCode == http.StatusTooManyRequests {
@@ -121,7 +122,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
return true return true
} }
func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err model.ErrorWithStatusCode) { func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {

View File

@@ -109,7 +109,6 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
ctx := c.Request.Context()
if !config.RegisterEnabled { if !config.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
@@ -167,7 +166,7 @@ func Register(c *gin.Context) {
if config.EmailVerificationEnabled { if config.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(ctx, inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -363,7 +362,6 @@ func GetSelf(c *gin.Context) {
} }
func UpdateUser(c *gin.Context) { func UpdateUser(c *gin.Context) {
ctx := c.Request.Context()
var updatedUser model.User var updatedUser model.User
err := json.NewDecoder(c.Request.Body).Decode(&updatedUser) err := json.NewDecoder(c.Request.Body).Decode(&updatedUser)
if err != nil || updatedUser.Id == 0 { if err != nil || updatedUser.Id == 0 {
@@ -418,7 +416,7 @@ func UpdateUser(c *gin.Context) {
return return
} }
if originUser.Quota != updatedUser.Quota { if originUser.Quota != updatedUser.Quota {
model.RecordLog(ctx, originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota))) model.RecordLog(originUser.Id, model.LogTypeManage, fmt.Sprintf("管理员将用户额度从 %s修改为 %s", common.LogQuota(originUser.Quota), common.LogQuota(updatedUser.Quota)))
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
@@ -537,7 +535,6 @@ func DeleteSelf(c *gin.Context) {
} }
func CreateUser(c *gin.Context) { func CreateUser(c *gin.Context) {
ctx := c.Request.Context()
var user model.User var user model.User
err := json.NewDecoder(c.Request.Body).Decode(&user) err := json.NewDecoder(c.Request.Body).Decode(&user)
if err != nil || user.Username == "" || user.Password == "" { if err != nil || user.Username == "" || user.Password == "" {
@@ -571,7 +568,7 @@ func CreateUser(c *gin.Context) {
Password: user.Password, Password: user.Password,
DisplayName: user.DisplayName, DisplayName: user.DisplayName,
} }
if err := cleanUser.Insert(ctx, 0); err != nil { if err := cleanUser.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -750,7 +747,6 @@ type topUpRequest struct {
} }
func TopUp(c *gin.Context) { func TopUp(c *gin.Context) {
ctx := c.Request.Context()
req := topUpRequest{} req := topUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@@ -761,7 +757,7 @@ func TopUp(c *gin.Context) {
return return
} }
id := c.GetInt("id") id := c.GetInt("id")
quota, err := model.Redeem(ctx, req.Key, id) quota, err := model.Redeem(req.Key, id)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -784,7 +780,6 @@ type adminTopUpRequest struct {
} }
func AdminTopUp(c *gin.Context) { func AdminTopUp(c *gin.Context) {
ctx := c.Request.Context()
req := adminTopUpRequest{} req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
if err != nil { if err != nil {
@@ -805,7 +800,7 @@ func AdminTopUp(c *gin.Context) {
if req.Remark == "" { if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
} }
model.RecordTopupLog(ctx, req.UserId, req.Remark, req.Quota) model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",

8
go.mod
View File

@@ -25,7 +25,7 @@ require (
github.com/pkoukk/tiktoken-go v0.1.7 github.com/pkoukk/tiktoken-go v0.1.7
github.com/smartystreets/goconvey v1.8.1 github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.31.0 golang.org/x/crypto v0.24.0
golang.org/x/image v0.18.0 golang.org/x/image v0.18.0
google.golang.org/api v0.187.0 google.golang.org/api v0.187.0
gorm.io/driver/mysql v1.5.6 gorm.io/driver/mysql v1.5.6
@@ -99,9 +99,9 @@ require (
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/net v0.26.0 // indirect golang.org/x/net v0.26.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.28.0 // indirect golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect

16
go.sum
View File

@@ -222,8 +222,8 @@ golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ=
golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E=
@@ -244,20 +244,20 @@ golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbht
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -2,24 +2,21 @@ package middleware
import ( import (
"fmt" "fmt"
"net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"net/http"
"strconv"
) )
type ModelRequest struct { type ModelRequest struct {
Model string `json:"model" form:"model"` Model string `json:"model"`
} }
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
ctx := c.Request.Context()
userId := c.GetInt(ctxkey.Id) userId := c.GetInt(ctxkey.Id)
userGroup, _ := model.CacheGetUserGroup(userId) userGroup, _ := model.CacheGetUserGroup(userId)
c.Set(ctxkey.Group, userGroup) c.Set(ctxkey.Group, userGroup)
@@ -55,7 +52,6 @@ func Distribute() func(c *gin.Context) {
return return
} }
} }
logger.Debugf(ctx, "user id %d, user group: %s, request model: %s, using channel #%d", userId, userGroup, requestModel, channel.Id)
SetupContextForSelectedChannel(c, channel, requestModel) SetupContextForSelectedChannel(c, channel, requestModel)
c.Next() c.Next()
} }
@@ -65,9 +61,6 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set(ctxkey.Channel, channel.Type) c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.ChannelId, channel.Id) c.Set(ctxkey.ChannelId, channel.Id)
c.Set(ctxkey.ChannelName, channel.Name) c.Set(ctxkey.ChannelName, channel.Name)
if channel.SystemPrompt != nil && *channel.SystemPrompt != "" {
c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt)
}
c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.ModelMapping, channel.GetModelMapping())
c.Set(ctxkey.OriginalModel, modelName) // for retry c.Set(ctxkey.OriginalModel, modelName) // for retry
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))

View File

@@ -1,27 +0,0 @@
package middleware
import (
"compress/gzip"
"github.com/gin-gonic/gin"
"io"
"net/http"
)
func GzipDecodeMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if c.GetHeader("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(c.Request.Body)
if err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
defer gzipReader.Close()
// Replace the request body with the decompressed data
c.Request.Body = io.NopCloser(gzipReader)
}
// Continue processing the request
c.Next()
}
}

View File

@@ -1,8 +1,8 @@
package middleware package middleware
import ( import (
"context"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
) )
@@ -10,7 +10,7 @@ func RequestId() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
id := helper.GenRequestID() id := helper.GenRequestID()
c.Set(helper.RequestIdKey, id) c.Set(helper.RequestIdKey, id)
ctx := helper.SetRequestID(c.Request.Context(), id) ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx) c.Request = c.Request.WithContext(ctx)
c.Header(helper.RequestIdKey, id) c.Header(helper.RequestIdKey, id)
c.Next() c.Next()

View File

@@ -37,7 +37,6 @@ type Channel struct {
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` Priority *int64 `json:"priority" gorm:"bigint;default:0"`
Config string `json:"config"` Config string `json:"config"`
SystemPrompt *string `json:"system_prompt" gorm:"type:text"`
} }
type ChannelConfig struct { type ChannelConfig struct {

View File

@@ -3,32 +3,26 @@ package model
import ( import (
"context" "context"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
) )
type Log struct { type Log struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id" gorm:"index"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
Type int `json:"type" gorm:"index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"`
Content string `json:"content"` Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
Quota int `json:"quota" gorm:"default:0"` Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"` ChannelId int `json:"channel" gorm:"index"`
RequestId string `json:"request_id" gorm:"default:''"`
ElapsedTime int64 `json:"elapsed_time" gorm:"default:0"` // unit is ms
IsStream bool `json:"is_stream" gorm:"default:false"`
SystemPromptReset bool `json:"system_prompt_reset" gorm:"default:false"`
} }
const ( const (
@@ -37,21 +31,9 @@ const (
LogTypeConsume LogTypeConsume
LogTypeManage LogTypeManage
LogTypeSystem LogTypeSystem
LogTypeTest
) )
func recordLogHelper(ctx context.Context, log *Log) { func RecordLog(userId int, logType int, content string) {
requestId := helper.GetRequestID(ctx)
log.RequestId = requestId
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
return
}
logger.Infof(ctx, "record log: %+v", log)
}
func RecordLog(ctx context.Context, userId int, logType int, content string) {
if logType == LogTypeConsume && !config.LogConsumeEnabled { if logType == LogTypeConsume && !config.LogConsumeEnabled {
return return
} }
@@ -62,10 +44,13 @@ func RecordLog(ctx context.Context, userId int, logType int, content string) {
Type: logType, Type: logType,
Content: content, Content: content,
} }
recordLogHelper(ctx, log) err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordTopupLog(ctx context.Context, userId int, content string, quota int) { func RecordTopupLog(userId int, content string, quota int) {
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
@@ -74,23 +59,34 @@ func RecordTopupLog(ctx context.Context, userId int, content string, quota int)
Content: content, Content: content,
Quota: quota, Quota: quota,
} }
recordLogHelper(ctx, log) err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
} }
func RecordConsumeLog(ctx context.Context, log *Log) { func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled { if !config.LogConsumeEnabled {
return return
} }
log.Username = GetUsernameById(log.UserId) log := &Log{
log.CreatedAt = helper.GetTimestamp() UserId: userId,
log.Type = LogTypeConsume Username: GetUsernameById(userId),
recordLogHelper(ctx, log) CreatedAt: helper.GetTimestamp(),
} Type: LogTypeConsume,
Content: content,
func RecordTestLog(ctx context.Context, log *Log) { PromptTokens: promptTokens,
log.CreatedAt = helper.GetTimestamp() CompletionTokens: completionTokens,
log.Type = LogTypeTest TokenName: tokenName,
recordLogHelper(ctx, log) ModelName: modelName,
Quota: int(quota),
ChannelId: channelId,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
}
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
@@ -156,11 +152,7 @@ func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
ifnull := "ifnull" tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(quota),0)", ifnull))
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -184,11 +176,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
ifnull := "ifnull" tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if common.UsingPostgreSQL {
ifnull = "COALESCE"
}
tx := LOG_DB.Table("logs").Select(fmt.Sprintf("%s(sum(prompt_tokens),0) + %s(sum(completion_tokens),0)", ifnull, ifnull))
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }

View File

@@ -28,7 +28,6 @@ func InitOptionMap() {
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled)
config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled)
@@ -131,8 +130,6 @@ func updateOptionMap(key string, value string) (err error) {
config.EmailVerificationEnabled = boolValue config.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue config.GitHubOAuthEnabled = boolValue
case "OidcEnabled":
config.OidcEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue config.WeChatAuthEnabled = boolValue
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
@@ -179,18 +176,6 @@ func updateOptionMap(key string, value string) (err error) {
config.LarkClientId = value config.LarkClientId = value
case "LarkClientSecret": case "LarkClientSecret":
config.LarkClientSecret = value config.LarkClientSecret = value
case "OidcClientId":
config.OidcClientId = value
case "OidcClientSecret":
config.OidcClientSecret = value
case "OidcWellKnown":
config.OidcWellKnown = value
case "OidcAuthorizationEndpoint":
config.OidcAuthorizationEndpoint = value
case "OidcTokenEndpoint":
config.OidcTokenEndpoint = value
case "OidcUserinfoEndpoint":
config.OidcUserinfoEndpoint = value
case "Footer": case "Footer":
config.Footer = value config.Footer = value
case "SystemName": case "SystemName":

View File

@@ -1,14 +1,11 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm"
) )
const ( const (
@@ -51,7 +48,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err return &redemption, err
} }
func Redeem(ctx context.Context, key string, userId int) (quota int64, err error) { func Redeem(key string, userId int) (quota int64, err error) {
if key == "" { if key == "" {
return 0, errors.New("未提供兑换码") return 0, errors.New("未提供兑换码")
} }
@@ -85,7 +82,7 @@ func Redeem(ctx context.Context, key string, userId int) (quota int64, err error
if err != nil { if err != nil {
return 0, errors.New("兑换失败," + err.Error()) return 0, errors.New("兑换失败," + err.Error())
} }
RecordLog(ctx, userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota))) RecordLog(userId, LogTypeTopup, fmt.Sprintf("通过兑换码充值 %s", common.LogQuota(redemption.Quota)))
return redemption.Quota, nil return redemption.Quota, nil
} }

View File

@@ -30,7 +30,7 @@ type Token struct {
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"type:text"` // allowed models Models *string `json:"models" gorm:"default:''"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
} }
@@ -121,40 +121,30 @@ func GetTokenById(id int) (*Token, error) {
return &token, err return &token, err
} }
func (t *Token) Insert() error { func (token *Token) Insert() error {
var err error var err error
err = DB.Create(t).Error err = DB.Create(token).Error
return err return err
} }
// Update Make sure your token's fields is completed, because this will update non-zero values // Update Make sure your token's fields is completed, because this will update non-zero values
func (t *Token) Update() error { func (token *Token) Update() error {
var err error var err error
err = DB.Model(t).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(t).Error err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err return err
} }
func (t *Token) SelectUpdate() error { func (token *Token) SelectUpdate() error {
// This can update zero values // This can update zero values
return DB.Model(t).Select("accessed_time", "status").Updates(t).Error return DB.Model(token).Select("accessed_time", "status").Updates(token).Error
} }
func (t *Token) Delete() error { func (token *Token) Delete() error {
var err error var err error
err = DB.Delete(t).Error err = DB.Delete(token).Error
return err return err
} }
func (t *Token) GetModels() string {
if t == nil {
return ""
}
if t.Models == nil {
return ""
}
return *t.Models
}
func DeleteTokenById(id int, userId int) (err error) { func DeleteTokenById(id int, userId int) (err error) {
// Why we need userId here? In case user want to delete other's token. // Why we need userId here? In case user want to delete other's token.
if id == 0 || userId == 0 { if id == 0 || userId == 0 {
@@ -264,14 +254,14 @@ func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) { func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId) token, err := GetTokenById(tokenId)
if err != nil {
return err
}
if quota > 0 { if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota) err = DecreaseUserQuota(token.UserId, quota)
} else { } else {
err = IncreaseUserQuota(token.UserId, -quota) err = IncreaseUserQuota(token.UserId, -quota)
} }
if err != nil {
return err
}
if !token.UnlimitedQuota { if !token.UnlimitedQuota {
if quota > 0 { if quota > 0 {
err = DecreaseTokenQuota(tokenId, quota) err = DecreaseTokenQuota(tokenId, quota)

View File

@@ -1,19 +1,16 @@
package model package model
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings"
"gorm.io/gorm"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
"strings"
) )
const ( const (
@@ -42,7 +39,6 @@ type User struct {
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"` Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -95,7 +91,7 @@ func GetUserById(id int, selectAll bool) (*User, error) {
if selectAll { if selectAll {
err = DB.First(&user, "id = ?", id).Error err = DB.First(&user, "id = ?", id).Error
} else { } else {
err = DB.Omit("password", "access_token").First(&user, "id = ?", id).Error err = DB.Omit("password").First(&user, "id = ?", id).Error
} }
return &user, err return &user, err
} }
@@ -117,7 +113,7 @@ func DeleteUserById(id int) (err error) {
return user.Delete() return user.Delete()
} }
func (user *User) Insert(ctx context.Context, inviterId int) error { func (user *User) Insert(inviterId int) error {
var err error var err error
if user.Password != "" { if user.Password != "" {
user.Password, err = common.Password2Hash(user.Password) user.Password, err = common.Password2Hash(user.Password)
@@ -133,16 +129,16 @@ func (user *User) Insert(ctx context.Context, inviterId int) error {
return result.Error return result.Error
} }
if config.QuotaForNewUser > 0 { if config.QuotaForNewUser > 0 {
RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if config.QuotaForInvitee > 0 { if config.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
RecordLog(ctx, user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
} }
if config.QuotaForInviter > 0 { if config.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) _ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
RecordLog(ctx, inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
} }
} }
// create default token // create default token
@@ -249,14 +245,6 @@ func (user *User) FillUserByLarkId() error {
return nil return nil
} }
func (user *User) FillUserByOidcId() error {
if user.OidcId == "" {
return errors.New("oidc id 为空!")
}
DB.Where(User{OidcId: user.OidcId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -289,10 +277,6 @@ func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
} }
func IsOidcIdAlreadyTaken(oidcId string) bool {
return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool { func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
} }

View File

@@ -1,11 +1,10 @@
package monitor package monitor
import ( import (
"net/http"
"strings"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
) )
func ShouldDisableChannel(err *model.Error, statusCode int) bool { func ShouldDisableChannel(err *model.Error, statusCode int) bool {
@@ -19,23 +18,31 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool {
return true return true
} }
switch err.Type { switch err.Type {
case "insufficient_quota", "authentication_error", "permission_error", "forbidden": case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true return true
} }
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true return true
} }
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
lowerMessage := strings.ToLower(err.Message) return true
if strings.Contains(lowerMessage, "your access was terminated") || } else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
strings.Contains(lowerMessage, "violation of our policies") || return true
strings.Contains(lowerMessage, "your credit balance is too low") || }
strings.Contains(lowerMessage, "organization has been disabled") || //if strings.Contains(err.Message, "quota") {
strings.Contains(lowerMessage, "credit") || // return true
strings.Contains(lowerMessage, "balance") || //}
strings.Contains(lowerMessage, "permission denied") || if strings.Contains(err.Message, "credit") {
strings.Contains(lowerMessage, "organization has been restricted") || // groq return true
strings.Contains(lowerMessage, "已欠费") { }
if strings.Contains(err.Message, "balance") {
return true return true
} }
return false return false

View File

@@ -16,7 +16,6 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/proxy" "github.com/songquanpeng/one-api/relay/adaptor/proxy"
"github.com/songquanpeng/one-api/relay/adaptor/replicate"
"github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/vertexai"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei" "github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -62,8 +61,6 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
return &vertexai.Adaptor{} return &vertexai.Adaptor{}
case apitype.Proxy: case apitype.Proxy:
return &proxy.Adaptor{} return &proxy.Adaptor{}
case apitype.Replicate:
return &replicate.Adaptor{}
} }
return nil return nil
} }

View File

@@ -1,23 +1,7 @@
package ali package ali
var ModelList = []string{ var ModelList = []string{
"qwen-turbo", "qwen-turbo-latest", "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"qwen-plus", "qwen-plus-latest", "text-embedding-v1",
"qwen-max", "qwen-max-latest",
"qwen-max-longcontext",
"qwen-vl-max", "qwen-vl-max-latest", "qwen-vl-plus", "qwen-vl-plus-latest",
"qwen-vl-ocr", "qwen-vl-ocr-latest",
"qwen-audio-turbo",
"qwen-math-plus", "qwen-math-plus-latest", "qwen-math-turbo", "qwen-math-turbo-latest",
"qwen-coder-plus", "qwen-coder-plus-latest", "qwen-coder-turbo", "qwen-coder-turbo-latest",
"qwq-32b-preview", "qwen2.5-72b-instruct", "qwen2.5-32b-instruct", "qwen2.5-14b-instruct", "qwen2.5-7b-instruct", "qwen2.5-3b-instruct", "qwen2.5-1.5b-instruct", "qwen2.5-0.5b-instruct",
"qwen2-72b-instruct", "qwen2-57b-a14b-instruct", "qwen2-7b-instruct", "qwen2-1.5b-instruct", "qwen2-0.5b-instruct",
"qwen1.5-110b-chat", "qwen1.5-72b-chat", "qwen1.5-32b-chat", "qwen1.5-14b-chat", "qwen1.5-7b-chat", "qwen1.5-1.8b-chat", "qwen1.5-0.5b-chat",
"qwen-72b-chat", "qwen-14b-chat", "qwen-7b-chat", "qwen-1.8b-chat", "qwen-1.8b-longcontext-chat",
"qwen2-vl-7b-instruct", "qwen2-vl-2b-instruct", "qwen-vl-v1", "qwen-vl-chat-v1",
"qwen2-audio-instruct", "qwen-audio-chat",
"qwen2.5-math-72b-instruct", "qwen2.5-math-7b-instruct", "qwen2.5-math-1.5b-instruct", "qwen2-math-72b-instruct", "qwen2-math-7b-instruct", "qwen2-math-1.5b-instruct",
"qwen2.5-coder-32b-instruct", "qwen2.5-coder-14b-instruct", "qwen2.5-coder-7b-instruct", "qwen2.5-coder-3b-instruct", "qwen2.5-coder-1.5b-instruct", "qwen2.5-coder-0.5b-instruct",
"text-embedding-v1", "text-embedding-v3", "text-embedding-v2", "text-embedding-async-v2", "text-embedding-async-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
} }

View File

@@ -3,7 +3,6 @@ package ali
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
@@ -36,7 +35,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
enableSearch = true enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
} }
request.TopP = helper.Float64PtrMax(request.TopP, 0.9999) if request.TopP >= 1 {
request.TopP = 0.9999
}
return &ChatRequest{ return &ChatRequest{
Model: aliModel, Model: aliModel,
Input: Input{ Input: Input{
@@ -58,7 +59,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{ return &EmbeddingRequest{
Model: request.Model, Model: "text-embedding-v1",
Input: struct { Input: struct {
Texts []string `json:"texts"` Texts []string `json:"texts"`
}{ }{
@@ -101,9 +102,8 @@ func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStat
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil
} }
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse) fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -16,13 +16,13 @@ type Input struct {
} }
type Parameters struct { type Parameters struct {
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"` Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"` EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"` IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"` ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"` Tools []model.Tool `json:"tools,omitempty"`
} }

View File

@@ -3,10 +3,7 @@ package anthropic
var ModelList = []string{ var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1", "claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307", "claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-sonnet-20240229", "claude-3-sonnet-20240229",
"claude-3-opus-20240229", "claude-3-opus-20240229",
"claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-5-sonnet-latest",
} }

View File

@@ -48,8 +48,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`

View File

@@ -29,13 +29,10 @@ var AwsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1", "claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2", "claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1", "claude-2.1": "anthropic.claude-v2:1",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0", "claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
} }
func awsModelID(requestModel string) (string, error) { func awsModelID(requestModel string) (string, error) {

View File

@@ -11,8 +11,8 @@ type Request struct {
Messages []anthropic.Message `json:"messages"` Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"`

View File

@@ -4,10 +4,10 @@ package aws
// //
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html // https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct { type Request struct {
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"` MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
} }
// Response is the response from AWS Llama3 // Response is the response from AWS Llama3

View File

@@ -35,9 +35,9 @@ type Message struct {
type ChatRequest struct { type ChatRequest struct {
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
PenaltyScore *float64 `json:"penalty_score,omitempty"` PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"` DisableSearch bool `json:"disable_search,omitempty"`

View File

@@ -1,7 +1,6 @@
package cloudflare package cloudflare
var ModelList = []string{ var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16", "@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8", "@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1", "@cf/mistral/mistral-7b-instruct-v0.1",

View File

@@ -9,5 +9,5 @@ type Request struct {
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
} }

View File

@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
K: textRequest.TopK, K: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
FrequencyPenalty: textRequest.FrequencyPenalty, FrequencyPenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.PresencePenalty, PresencePenalty: textRequest.FrequencyPenalty,
Seed: int(textRequest.Seed), Seed: int(textRequest.Seed),
} }
if cohereRequest.Model == "" { if cohereRequest.Model == "" {

View File

@@ -10,15 +10,15 @@ type Request struct {
PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" PromptTruncation string `json:"prompt_truncation,omitempty"` // 默认值为"AUTO"
Connectors []Connector `json:"connectors,omitempty"` Connectors []Connector `json:"connectors,omitempty"`
Documents []Document `json:"documents,omitempty"` Documents []Document `json:"documents,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` // 默认值为0.3 Temperature float64 `json:"temperature,omitempty"` // 默认值为0.3
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
MaxInputTokens int `json:"max_input_tokens,omitempty"` MaxInputTokens int `json:"max_input_tokens,omitempty"`
K int `json:"k,omitempty"` // 默认值为0 K int `json:"k,omitempty"` // 默认值为0
P *float64 `json:"p,omitempty"` // 默认值为0.75 P float64 `json:"p,omitempty"` // 默认值为0.75
Seed int `json:"seed,omitempty"` Seed int `json:"seed,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0 FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // 默认值为0.0
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // 默认值为0.0 PresencePenalty float64 `json:"presence_penalty,omitempty"` // 默认值为0.0
Tools []Tool `json:"tools,omitempty"` Tools []Tool `json:"tools,omitempty"`
ToolResults []ToolResult `json:"tool_results,omitempty"` ToolResults []ToolResult `json:"tool_results,omitempty"`
} }

View File

@@ -7,12 +7,8 @@ import (
) )
func GetRequestURL(meta *meta.Meta) (string, error) { func GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.Mode { if meta.Mode == relaymode.ChatCompletions {
case relaymode.ChatCompletions:
return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil return fmt.Sprintf("%s/api/v3/chat/completions", meta.BaseURL), nil
case relaymode.Embeddings:
return fmt.Sprintf("%s/api/v3/embeddings", meta.BaseURL), nil
default:
} }
return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode) return "", fmt.Errorf("unsupported relay mode %d for doubao", meta.Mode)
} }

View File

@@ -7,6 +7,7 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor" channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
@@ -23,15 +24,7 @@ func (a *Adaptor) Init(meta *meta.Meta) {
} }
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
var defaultVersion string version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion)
switch meta.ActualModelName {
case "gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp",
"gemini-2.0-flash-thinking-exp-01-21":
defaultVersion = "v1beta"
}
version := helper.AssignOrDefault(meta.Config.APIVersion, defaultVersion)
action := "" action := ""
switch meta.Mode { switch meta.Mode {
case relaymode.Embeddings: case relaymode.Embeddings:
@@ -43,7 +36,6 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if meta.IsStream { if meta.IsStream {
action = "streamGenerateContent?alt=sse" action = "streamGenerateContent?alt=sse"
} }
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
} }

View File

@@ -3,9 +3,6 @@ package gemini
// https://ai.google.dev/models/gemini // https://ai.google.dev/models/gemini
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-1.0-pro", "gemini-pro", "gemini-1.0-pro-001", "gemini-1.5-pro",
"gemini-1.5-flash", "gemini-1.5-pro", "gemini-pro-vision", "gemini-1.0-pro-vision-001", "embedding-001", "text-embedding-004",
"text-embedding-004", "aqa",
"gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
} }

View File

@@ -4,12 +4,11 @@ import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
@@ -29,11 +28,6 @@ const (
VisionMaxImageNum = 16 VisionMaxImageNum = 16
) )
var mimeTypeMap = map[string]string{
"json_object": "application/json",
"text": "text/plain",
}
// Setting safety to the lowest possible values since Gemini is already powerless enough // Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{ geminiRequest := ChatRequest{
@@ -55,10 +49,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
Category: "HARM_CATEGORY_DANGEROUS_CONTENT", Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting, Threshold: config.GeminiSafetySetting,
}, },
{
Category: "HARM_CATEGORY_CIVIC_INTEGRITY",
Threshold: config.GeminiSafetySetting,
},
}, },
GenerationConfig: ChatGenerationConfig{ GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
@@ -66,15 +56,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.MaxTokens,
}, },
} }
if textRequest.ResponseFormat != nil {
if mimeType, ok := mimeTypeMap[textRequest.ResponseFormat.Type]; ok {
geminiRequest.GenerationConfig.ResponseMimeType = mimeType
}
if textRequest.ResponseFormat.JsonSchema != nil {
geminiRequest.GenerationConfig.ResponseSchema = textRequest.ResponseFormat.JsonSchema.Schema
geminiRequest.GenerationConfig.ResponseMimeType = mimeTypeMap["json_object"]
}
}
if textRequest.Tools != nil { if textRequest.Tools != nil {
functions := make([]model.Function, 0, len(textRequest.Tools)) functions := make([]model.Function, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools { for _, tool := range textRequest.Tools {
@@ -251,14 +232,7 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
if candidate.Content.Parts[0].FunctionCall != nil { if candidate.Content.Parts[0].FunctionCall != nil {
choice.Message.ToolCalls = getToolCalls(&candidate) choice.Message.ToolCalls = getToolCalls(&candidate)
} else { } else {
var builder strings.Builder choice.Message.Content = candidate.Content.Parts[0].Text
for _, part := range candidate.Content.Parts {
if i > 0 {
builder.WriteString("\n")
}
builder.WriteString(part.Text)
}
choice.Message.Content = builder.String()
} }
} else { } else {
choice.Message.Content = "" choice.Message.Content = ""

View File

@@ -65,12 +65,10 @@ type ChatTools struct {
} }
type ChatGenerationConfig struct { type ChatGenerationConfig struct {
ResponseMimeType string `json:"responseMimeType,omitempty"` Temperature float64 `json:"temperature,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"` TopP float64 `json:"topP,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` TopK float64 `json:"topK,omitempty"`
TopP *float64 `json:"topP,omitempty"` MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
TopK float64 `json:"topK,omitempty"` CandidateCount int `json:"candidateCount,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"` StopSequences []string `json:"stopSequences,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
} }

View File

@@ -4,24 +4,9 @@ package groq
var ModelList = []string{ var ModelList = []string{
"gemma-7b-it", "gemma-7b-it",
"gemma2-9b-it", "llama2-7b-2048",
"llama-3.1-70b-versatile", "llama2-70b-4096",
"llama-3.1-8b-instant",
"llama-3.2-11b-text-preview",
"llama-3.2-11b-vision-preview",
"llama-3.2-1b-preview",
"llama-3.2-3b-preview",
"llama-3.2-11b-vision-preview",
"llama-3.2-90b-text-preview",
"llama-3.2-90b-vision-preview",
"llama-guard-3-8b",
"llama3-70b-8192",
"llama3-8b-8192",
"llama3-groq-70b-8192-tool-use-preview",
"llama3-groq-8b-8192-tool-use-preview",
"llava-v1.5-7b-4096-preview",
"mixtral-8x7b-32768", "mixtral-8x7b-32768",
"distil-whisper-large-v3-en", "llama3-8b-8192",
"whisper-large-v3", "llama3-70b-8192",
"whisper-large-v3-turbo",
} }

View File

@@ -24,7 +24,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md // https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
if meta.Mode == relaymode.Embeddings { if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embed", meta.BaseURL) fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
} }
return fullRequestURL, nil return fullRequestURL, nil
} }

View File

@@ -31,8 +31,6 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
TopP: request.TopP, TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty, FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty, PresencePenalty: request.PresencePenalty,
NumPredict: request.MaxTokens,
NumCtx: request.NumCtx,
}, },
Stream: request.Stream, Stream: request.Stream,
} }
@@ -120,10 +118,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
common.SetEventStreamHeaders(c) common.SetEventStreamHeaders(c)
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := strings.TrimPrefix(scanner.Text(), "}")
if strings.HasPrefix(data, "}") { data = data + "}"
data = strings.TrimPrefix(data, "}") + "}"
}
var ollamaResponse ChatResponse var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse) err := json.Unmarshal([]byte(data), &ollamaResponse)
@@ -161,15 +157,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{ return &EmbeddingRequest{
Model: request.Model, Model: request.Model,
Input: request.ParseInput(), Prompt: strings.Join(request.ParseInput(), " "),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
} }
} }
@@ -212,17 +201,15 @@ func embeddingResponseOllama2OpenAI(response *EmbeddingResponse) *openai.Embeddi
openAIEmbeddingResponse := openai.EmbeddingResponse{ openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list", Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, 1), Data: make([]openai.EmbeddingResponseItem, 0, 1),
Model: response.Model, Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: 0}, Usage: model.Usage{TotalTokens: 0},
} }
for i, embedding := range response.Embeddings { openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{ Object: `embedding`,
Object: `embedding`, Index: 0,
Index: i, Embedding: response.Embedding,
Embedding: embedding, })
})
}
return &openAIEmbeddingResponse return &openAIEmbeddingResponse
} }

View File

@@ -1,14 +1,12 @@
package ollama package ollama
type Options struct { type Options struct {
Seed int `json:"seed,omitempty"` Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
} }
type Message struct { type Message struct {
@@ -39,15 +37,11 @@ type ChatResponse struct {
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input []string `json:"input"` Prompt string `json:"prompt"`
// Truncate bool `json:"truncate,omitempty"`
Options *Options `json:"options,omitempty"`
// KeepAlive string `json:"keep_alive,omitempty"`
} }
type EmbeddingResponse struct { type EmbeddingResponse struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`
Model string `json:"model"` Embedding []float64 `json:"embedding,omitempty"`
Embeddings [][]float64 `json:"embeddings"`
} }

View File

@@ -75,13 +75,6 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
if request.Stream {
// always return usage in stream mode
if request.StreamOptions == nil {
request.StreamOptions = &model.StreamOptions{}
}
request.StreamOptions.IncludeUsage = true
}
return request, nil return request, nil
} }

View File

@@ -11,10 +11,8 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/mistral" "github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot" "github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/novita" "github.com/songquanpeng/one-api/relay/adaptor/novita"
"github.com/songquanpeng/one-api/relay/adaptor/siliconflow"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun" "github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/adaptor/togetherai" "github.com/songquanpeng/one-api/relay/adaptor/togetherai"
"github.com/songquanpeng/one-api/relay/adaptor/xai"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
) )
@@ -32,8 +30,6 @@ var CompatibleChannels = []int{
channeltype.DeepSeek, channeltype.DeepSeek,
channeltype.TogetherAI, channeltype.TogetherAI,
channeltype.Novita, channeltype.Novita,
channeltype.SiliconFlow,
channeltype.XAI,
} }
func GetCompatibleChannelMeta(channelType int) (string, []string) { func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -64,10 +60,6 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "doubao", doubao.ModelList return "doubao", doubao.ModelList
case channeltype.Novita: case channeltype.Novita:
return "novita", novita.ModelList return "novita", novita.ModelList
case channeltype.SiliconFlow:
return "siliconflow", siliconflow.ModelList
case channeltype.XAI:
return "xai", xai.ModelList
default: default:
return "openai", ModelList return "openai", ModelList
} }

View File

@@ -8,10 +8,6 @@ var ModelList = []string{
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o", "gpt-4o-2024-05-13",
"gpt-4o-2024-08-06",
"gpt-4o-2024-11-20",
"chatgpt-4o-latest",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4-vision-preview", "gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",
@@ -21,7 +17,4 @@ var ModelList = []string{
"dall-e-2", "dall-e-3", "dall-e-2", "dall-e-3",
"whisper-1", "whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106", "tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
"o1", "o1-2024-12-17",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
} }

View File

@@ -2,16 +2,15 @@ package openai
import ( import (
"fmt" "fmt"
"strings"
"github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"strings"
) )
func ResponseText2Usage(responseText string, modelName string, promptTokens int) *model.Usage { func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{} usage := &model.Usage{}
usage.PromptTokens = promptTokens usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modelName) usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage return usage
} }

View File

@@ -55,8 +55,8 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
render.StringData(c, data) // if error happened, pass the data to client render.StringData(c, data) // if error happened, pass the data to client
continue // just ignore the error continue // just ignore the error
} }
if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil { if len(streamResponse.Choices) == 0 {
// but for empty choice and no usage, we should not pass it to client, this is for azure // but for empty choice, we should not pass it to client, this is for azure
continue // just ignore empty choice continue // just ignore empty choice
} }
render.StringData(c, data) render.StringData(c, data)

View File

@@ -110,7 +110,7 @@ func CountTokenMessages(messages []model.Message, model string) int {
if imageUrl["detail"] != nil { if imageUrl["detail"] != nil {
detail = imageUrl["detail"].(string) detail = imageUrl["detail"].(string)
} }
imageTokens, err := countImageTokens(url, detail, model) imageTokens, err := countImageTokens(url, detail)
if err != nil { if err != nil {
logger.SysError("error counting image tokens: " + err.Error()) logger.SysError("error counting image tokens: " + err.Error())
} else { } else {
@@ -134,15 +134,11 @@ const (
lowDetailCost = 85 lowDetailCost = 85
highDetailCostPerTile = 170 highDetailCostPerTile = 170
additionalCost = 85 additionalCost = 85
// gpt-4o-mini cost higher than other model
gpt4oMiniLowDetailCost = 2833
gpt4oMiniHighDetailCost = 5667
gpt4oMiniAdditionalCost = 2833
) )
// https://platform.openai.com/docs/guides/vision/calculating-costs // https://platform.openai.com/docs/guides/vision/calculating-costs
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb // https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
func countImageTokens(url string, detail string, model string) (_ int, err error) { func countImageTokens(url string, detail string) (_ int, err error) {
var fetchSize = true var fetchSize = true
var width, height int var width, height int
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding // Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
@@ -176,9 +172,6 @@ func countImageTokens(url string, detail string, model string) (_ int, err error
} }
switch detail { switch detail {
case "low": case "low":
if strings.HasPrefix(model, "gpt-4o-mini") {
return gpt4oMiniLowDetailCost, nil
}
return lowDetailCost, nil return lowDetailCost, nil
case "high": case "high":
if fetchSize { if fetchSize {
@@ -198,9 +191,6 @@ func countImageTokens(url string, detail string, model string) (_ int, err error
height = int(float64(height) * ratio) height = int(float64(height) * ratio)
} }
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512)) numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
if strings.HasPrefix(model, "gpt-4o-mini") {
return numSquares*gpt4oMiniHighDetailCost + gpt4oMiniAdditionalCost, nil
}
result := numSquares*highDetailCostPerTile + additionalCost result := numSquares*highDetailCostPerTile + additionalCost
return result, nil return result, nil
default: default:

View File

@@ -1,16 +1,8 @@
package openai package openai
import ( import "github.com/songquanpeng/one-api/relay/model"
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
)
func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode {
logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err))
Error := model.Error{ Error := model.Error{
Message: err.Error(), Message: err.Error(),
Type: "one_api_error", Type: "one_api_error",

View File

@@ -19,11 +19,11 @@ type Prompt struct {
} }
type ChatRequest struct { type ChatRequest struct {
Prompt Prompt `json:"prompt"` Prompt Prompt `json:"prompt"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"` CandidateCount int `json:"candidateCount,omitempty"`
TopP *float64 `json:"topP,omitempty"` TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"` TopK int `json:"topK,omitempty"`
} }
type Error struct { type Error struct {

View File

@@ -1,136 +0,0 @@
package replicate
import (
"fmt"
"io"
"net/http"
"slices"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
meta *meta.Meta
}
// ConvertImageRequest implements adaptor.Adaptor.
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
return DrawImageRequest{
Input: ImageInput{
Steps: 25,
Prompt: request.Prompt,
Guidance: 3,
Seed: int(time.Now().UnixNano()),
SafetyTolerance: 5,
NImages: 1, // replicate will always return 1 image
Width: 1440,
Height: 1440,
AspectRatio: "1:1",
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if !request.Stream {
// TODO: support non-stream mode
return nil, errors.Errorf("replicate models only support stream mode now, please set stream=true")
}
// Build the prompt from OpenAI messages
var promptBuilder strings.Builder
for _, message := range request.Messages {
switch msgCnt := message.Content.(type) {
case string:
promptBuilder.WriteString(message.Role)
promptBuilder.WriteString(": ")
promptBuilder.WriteString(msgCnt)
promptBuilder.WriteString("\n")
default:
}
}
replicateRequest := ReplicateChatRequest{
Input: ChatInput{
Prompt: promptBuilder.String(),
MaxTokens: request.MaxTokens,
Temperature: 1.0,
TopP: 1.0,
PresencePenalty: 0.0,
FrequencyPenalty: 0.0,
},
}
// Map optional fields
if request.Temperature != nil {
replicateRequest.Input.Temperature = *request.Temperature
}
if request.TopP != nil {
replicateRequest.Input.TopP = *request.TopP
}
if request.PresencePenalty != nil {
replicateRequest.Input.PresencePenalty = *request.PresencePenalty
}
if request.FrequencyPenalty != nil {
replicateRequest.Input.FrequencyPenalty = *request.FrequencyPenalty
}
if request.MaxTokens > 0 {
replicateRequest.Input.MaxTokens = request.MaxTokens
} else if request.MaxTokens == 0 {
replicateRequest.Input.MaxTokens = 500
}
return replicateRequest, nil
}
func (a *Adaptor) Init(meta *meta.Meta) {
a.meta = meta
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
if !slices.Contains(ModelList, meta.OriginModelName) {
return "", errors.Errorf("model %s not supported", meta.OriginModelName)
}
return fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
logger.Info(c, "send request to replicate")
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
case relaymode.ChatCompletions:
err, usage = ChatHandler(c, resp)
default:
err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "replicate"
}

View File

@@ -1,191 +0,0 @@
package replicate
import (
"bufio"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
func ChatHandler(c *gin.Context, resp *http.Response) (
srvErr *model.ErrorWithStatusCode, usage *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ChatResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ChatResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
if taskData.URLs.Stream == "" {
return errors.New("stream url is empty")
}
// request stream url
responseText, err := chatStreamHandler(c, taskData.URLs.Stream)
if err != nil {
return errors.Wrap(err, "chat stream handler")
}
ctxMeta := meta.GetByContext(c)
usage = openai.ResponseText2Usage(responseText,
ctxMeta.ActualModelName, ctxMeta.PromptTokens)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "chat_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, usage
}
const (
eventPrefix = "event: "
dataPrefix = "data: "
done = "[DONE]"
)
func chatStreamHandler(c *gin.Context, streamUrl string) (responseText string, err error) {
// request stream endpoint
streamReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, streamUrl, nil)
if err != nil {
return "", errors.Wrap(err, "new request to stream")
}
streamReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
streamReq.Header.Set("Accept", "text/event-stream")
streamReq.Header.Set("Cache-Control", "no-store")
resp, err := http.DefaultClient.Do(streamReq)
if err != nil {
return "", errors.Wrap(err, "do request to stream")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(resp.Body)
return "", errors.Errorf("bad status code [%d]%s", resp.StatusCode, string(payload))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
common.SetEventStreamHeaders(c)
doneRendered := false
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" {
continue
}
// Handle comments starting with ':'
if strings.HasPrefix(line, ":") {
continue
}
// Parse SSE fields
if strings.HasPrefix(line, eventPrefix) {
event := strings.TrimSpace(line[len(eventPrefix):])
var data string
// Read the following lines to get data and id
for scanner.Scan() {
nextLine := scanner.Text()
if nextLine == "" {
break
}
if strings.HasPrefix(nextLine, dataPrefix) {
data = nextLine[len(dataPrefix):]
} else if strings.HasPrefix(nextLine, "id:") {
// id = strings.TrimSpace(nextLine[len("id:"):])
}
}
if event == "output" {
render.StringData(c, data)
responseText += data
} else if event == "done" {
render.Done(c)
doneRendered = true
break
}
}
}
if err := scanner.Err(); err != nil {
return "", errors.Wrap(err, "scan stream")
}
if !doneRendered {
render.Done(c)
}
return responseText, nil
}

View File

@@ -1,58 +0,0 @@
package replicate
// ModelList is a list of models that can be used with Replicate.
//
// https://replicate.com/pricing
var ModelList = []string{
// -------------------------------------
// image model
// -------------------------------------
"black-forest-labs/flux-1.1-pro",
"black-forest-labs/flux-1.1-pro-ultra",
"black-forest-labs/flux-canny-dev",
"black-forest-labs/flux-canny-pro",
"black-forest-labs/flux-depth-dev",
"black-forest-labs/flux-depth-pro",
"black-forest-labs/flux-dev",
"black-forest-labs/flux-dev-lora",
"black-forest-labs/flux-fill-dev",
"black-forest-labs/flux-fill-pro",
"black-forest-labs/flux-pro",
"black-forest-labs/flux-redux-dev",
"black-forest-labs/flux-redux-schnell",
"black-forest-labs/flux-schnell",
"black-forest-labs/flux-schnell-lora",
"ideogram-ai/ideogram-v2",
"ideogram-ai/ideogram-v2-turbo",
"recraft-ai/recraft-v3",
"recraft-ai/recraft-v3-svg",
"stability-ai/stable-diffusion-3",
"stability-ai/stable-diffusion-3.5-large",
"stability-ai/stable-diffusion-3.5-large-turbo",
"stability-ai/stable-diffusion-3.5-medium",
// -------------------------------------
// language model
// -------------------------------------
"ibm-granite/granite-20b-code-instruct-8k",
"ibm-granite/granite-3.0-2b-instruct",
"ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k",
"meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3.1-405b-instruct",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct",
"mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1",
"mistralai/mixtral-8x7b-instruct-v0.1",
// -------------------------------------
// video model
// -------------------------------------
// "minimax/video-01", // TODO: implement the adaptor
}

View File

@@ -1,222 +0,0 @@
package replicate
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"image"
"image/png"
"io"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"golang.org/x/image/webp"
"golang.org/x/sync/errgroup"
)
// ImagesEditsHandler just copy response body to client
//
// https://replicate.com/black-forest-labs/flux-fill-pro
// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
// c.Writer.WriteHeader(resp.StatusCode)
// for k, v := range resp.Header {
// c.Writer.Header().Set(k, v[0])
// }
// if _, err := io.Copy(c.Writer, resp.Body); err != nil {
// return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
// }
// defer resp.Body.Close()
// return nil, nil
// }
var errNextLoop = errors.New("next_loop")
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
if resp.StatusCode != http.StatusCreated {
payload, _ := io.ReadAll(resp.Body)
return openai.ErrorWrapper(
errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)),
"bad_status_code", http.StatusInternalServerError),
nil
}
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
respData := new(ImageResponse)
if err = json.Unmarshal(respBody, respData); err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
for {
err = func() error {
// get task
taskReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, respData.URLs.Get, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey)
taskResp, err := http.DefaultClient.Do(taskReq)
if err != nil {
return errors.Wrap(err, "get task")
}
defer taskResp.Body.Close()
if taskResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(taskResp.Body)
return errors.Errorf("bad status code [%d]%s",
taskResp.StatusCode, string(payload))
}
taskBody, err := io.ReadAll(taskResp.Body)
if err != nil {
return errors.Wrap(err, "read task response")
}
taskData := new(ImageResponse)
if err = json.Unmarshal(taskBody, taskData); err != nil {
return errors.Wrap(err, "decode task response")
}
switch taskData.Status {
case "succeeded":
case "failed", "canceled":
return errors.Errorf("task failed: %s", taskData.Status)
default:
time.Sleep(time.Second * 3)
return errNextLoop
}
output, err := taskData.GetOutput()
if err != nil {
return errors.Wrap(err, "get output")
}
if len(output) == 0 {
return errors.New("response output is empty")
}
var mu sync.Mutex
var pool errgroup.Group
respBody := &openai.ImageResponse{
Created: taskData.CompletedAt.Unix(),
Data: []openai.ImageData{},
}
for _, imgOut := range output {
imgOut := imgOut
pool.Go(func() error {
// download image
downloadReq, err := http.NewRequestWithContext(c.Request.Context(),
http.MethodGet, imgOut, nil)
if err != nil {
return errors.Wrap(err, "new request")
}
imgResp, err := http.DefaultClient.Do(downloadReq)
if err != nil {
return errors.Wrap(err, "download image")
}
defer imgResp.Body.Close()
if imgResp.StatusCode != http.StatusOK {
payload, _ := io.ReadAll(imgResp.Body)
return errors.Errorf("bad status code [%d]%s",
imgResp.StatusCode, string(payload))
}
imgData, err := io.ReadAll(imgResp.Body)
if err != nil {
return errors.Wrap(err, "read image")
}
imgData, err = ConvertImageToPNG(imgData)
if err != nil {
return errors.Wrap(err, "convert image")
}
mu.Lock()
respBody.Data = append(respBody.Data, openai.ImageData{
B64Json: fmt.Sprintf("data:image/png;base64,%s",
base64.StdEncoding.EncodeToString(imgData)),
})
mu.Unlock()
return nil
})
}
if err := pool.Wait(); err != nil {
if len(respBody.Data) == 0 {
return errors.WithStack(err)
}
logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err))
}
c.JSON(http.StatusOK, respBody)
return nil
}()
if err != nil {
if errors.Is(err, errNextLoop) {
continue
}
return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil
}
break
}
return nil, nil
}
// ConvertImageToPNG converts a WebP image to PNG format
func ConvertImageToPNG(webpData []byte) ([]byte, error) {
// bypass if it's already a PNG image
if bytes.HasPrefix(webpData, []byte("\x89PNG")) {
return webpData, nil
}
// check if is jpeg, convert to png
if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) {
img, _, err := image.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode jpeg")
}
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}
// Decode the WebP image
img, err := webp.Decode(bytes.NewReader(webpData))
if err != nil {
return nil, errors.Wrap(err, "decode webp")
}
// Encode the image as PNG
var pngBuffer bytes.Buffer
if err := png.Encode(&pngBuffer, img); err != nil {
return nil, errors.Wrap(err, "encode png")
}
return pngBuffer.Bytes(), nil
}

View File

@@ -1,159 +0,0 @@
package replicate
import (
"time"
"github.com/pkg/errors"
)
// DrawImageRequest draw image by fluxpro
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type DrawImageRequest struct {
Input ImageInput `json:"input"`
}
// ImageInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema
type ImageInput struct {
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
ImagePrompt string `json:"image_prompt"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
Interval int `json:"interval" binding:"required,min=1,max=4"`
AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
Seed int `json:"seed"`
NImages int `json:"n_images" binding:"required,min=1,max=8"`
Width int `json:"width" binding:"required,min=256,max=1440"`
Height int `json:"height" binding:"required,min=256,max=1440"`
}
// InpaintingImageByFlusReplicateRequest is request to inpainting image by flux pro
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type InpaintingImageByFlusReplicateRequest struct {
Input FluxInpaintingInput `json:"input"`
}
// FluxInpaintingInput is input of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema
type FluxInpaintingInput struct {
Mask string `json:"mask" binding:"required"`
Image string `json:"image" binding:"required"`
Seed int `json:"seed"`
Steps int `json:"steps" binding:"required,min=1"`
Prompt string `json:"prompt" binding:"required,min=5"`
Guidance int `json:"guidance" binding:"required,min=2,max=5"`
OutputFormat string `json:"output_format"`
SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"`
PromptUnsampling bool `json:"prompt_unsampling"`
}
// ImageResponse is response of DrawImageByFluxProRequest
//
// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json
type ImageResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input DrawImageRequest `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output any `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs FluxURLs `json:"urls"`
Version string `json:"version"`
}
func (r *ImageResponse) GetOutput() ([]string, error) {
switch v := r.Output.(type) {
case string:
return []string{v}, nil
case []string:
return v, nil
case nil:
return nil, nil
case []interface{}:
// convert []interface{} to []string
ret := make([]string, len(v))
for idx, vv := range v {
if vvv, ok := vv.(string); ok {
ret[idx] = vvv
} else {
return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv)
}
}
return ret, nil
default:
return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output)
}
}
// FluxMetrics is metrics of ImageResponse
type FluxMetrics struct {
ImageCount int `json:"image_count"`
PredictTime float64 `json:"predict_time"`
TotalTime float64 `json:"total_time"`
}
// FluxURLs is urls of ImageResponse
type FluxURLs struct {
Get string `json:"get"`
Cancel string `json:"cancel"`
}
type ReplicateChatRequest struct {
Input ChatInput `json:"input" form:"input" binding:"required"`
}
// ChatInput is input of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/api/schema
type ChatInput struct {
TopK int `json:"top_k"`
TopP float64 `json:"top_p"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
MinTokens int `json:"min_tokens"`
Temperature float64 `json:"temperature"`
SystemPrompt string `json:"system_prompt"`
StopSequences string `json:"stop_sequences"`
PromptTemplate string `json:"prompt_template"`
PresencePenalty float64 `json:"presence_penalty"`
FrequencyPenalty float64 `json:"frequency_penalty"`
}
// ChatResponse is response of ChatByReplicateRequest
//
// https://replicate.com/meta/meta-llama-3.1-405b-instruct/examples?input=http&output=json
type ChatResponse struct {
CompletedAt time.Time `json:"completed_at"`
CreatedAt time.Time `json:"created_at"`
DataRemoved bool `json:"data_removed"`
Error string `json:"error"`
ID string `json:"id"`
Input ChatInput `json:"input"`
Logs string `json:"logs"`
Metrics FluxMetrics `json:"metrics"`
// Output could be `string` or `[]string`
Output []string `json:"output"`
StartedAt time.Time `json:"started_at"`
Status string `json:"status"`
URLs ChatResponseUrl `json:"urls"`
Version string `json:"version"`
}
// ChatResponseUrl is task urls of ChatResponse
type ChatResponseUrl struct {
Stream string `json:"stream"`
Get string `json:"get"`
Cancel string `json:"cancel"`
}

View File

@@ -1,36 +0,0 @@
package siliconflow
// https://docs.siliconflow.cn/docs/getting-started
var ModelList = []string{
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"THUDM/glm-4-9b-chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
}

View File

@@ -1,13 +1,7 @@
package stepfun package stepfun
var ModelList = []string{ var ModelList = []string{
"step-1-8k",
"step-1-32k", "step-1-32k",
"step-1-128k",
"step-1-256k",
"step-1-flash",
"step-2-16k",
"step-1v-8k",
"step-1v-32k", "step-1v-32k",
"step-1x-medium", "step-1-200k",
} }

View File

@@ -2,19 +2,16 @@ package tencent
import ( import (
"errors" "errors"
"io"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "io"
"net/http"
"strconv"
"strings"
) )
// https://cloud.tencent.com/document/api/1729/101837 // https://cloud.tencent.com/document/api/1729/101837
@@ -55,18 +52,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if err != nil { if err != nil {
return nil, err return nil, err
} }
var convertedRequest any tencentRequest := ConvertRequest(*request)
switch relayMode {
case relaymode.Embeddings:
a.Action = "GetEmbedding"
convertedRequest = ConvertEmbeddingRequest(*request)
default:
a.Action = "ChatCompletions"
convertedRequest = ConvertRequest(*request)
}
// we have to calculate the sign here // we have to calculate the sign here
a.Sign = GetSign(convertedRequest, a, secretId, secretKey) a.Sign = GetSign(*tencentRequest, a, secretId, secretKey)
return convertedRequest, nil return tencentRequest, nil
} }
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
@@ -86,12 +75,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
err, responseText = StreamHandler(c, resp) err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else { } else {
switch meta.Mode { err, usage = Handler(c, resp)
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
} }
return return
} }

View File

@@ -5,6 +5,4 @@ var ModelList = []string{
"hunyuan-standard", "hunyuan-standard",
"hunyuan-standard-256K", "hunyuan-standard-256K",
"hunyuan-pro", "hunyuan-pro",
"hunyuan-vision",
"hunyuan-embedding",
} }

View File

@@ -8,6 +8,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/render"
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
@@ -15,14 +16,11 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/common/render"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
@@ -41,73 +39,13 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
Model: &request.Model, Model: &request.Model,
Stream: &request.Stream, Stream: &request.Stream,
Messages: messages, Messages: messages,
TopP: request.TopP, TopP: &request.TopP,
Temperature: request.Temperature, Temperature: &request.Temperature,
} }
} }
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
InputList: request.ParseInput(),
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var tencentResponseP EmbeddingResponseP
err := json.NewDecoder(resp.Body).Decode(&tencentResponseP)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
tencentResponse := tencentResponseP.Response
if tencentResponse.Error.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: tencentResponse.Error.Message,
Code: tencentResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
requestModel := c.GetString(ctxkey.RequestModel)
fullTextResponse := embeddingResponseTencent2OpenAI(&tencentResponse)
fullTextResponse.Model = requestModel
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseTencent2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "hunyuan-embedding",
Usage: model.Usage{TotalTokens: response.EmbeddingUsage.TotalTokens},
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := openai.TextResponse{
Id: response.ReqID,
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: helper.GetTimestamp(),
Usage: model.Usage{ Usage: model.Usage{
@@ -210,7 +148,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
TencentResponse = responseP.Response TencentResponse = responseP.Response
if TencentResponse.Error.Code != "" { if TencentResponse.Error.Code != 0 {
return &model.ErrorWithStatusCode{ return &model.ErrorWithStatusCode{
Error: model.Error{ Error: model.Error{
Message: TencentResponse.Error.Message, Message: TencentResponse.Error.Message,
@@ -257,7 +195,7 @@ func hmacSha256(s, key string) string {
return string(hashed.Sum(nil)) return string(hashed.Sum(nil))
} }
func GetSign(req any, adaptor *Adaptor, secId, secKey string) string { func GetSign(req ChatRequest, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string // build canonical request string
host := "hunyuan.tencentcloudapi.com" host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST" httpRequestMethod := "POST"

View File

@@ -35,16 +35,16 @@ type ChatRequest struct {
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。 // 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。 // 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。 // 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP,omitempty"` TopP *float64 `json:"TopP"`
// 说明: // 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。 // 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。 // 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。 // 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature,omitempty"` Temperature *float64 `json:"Temperature"`
} }
type Error struct { type Error struct {
Code string `json:"Code"` Code int `json:"Code"`
Message string `json:"Message"` Message string `json:"Message"`
} }
@@ -61,41 +61,15 @@ type ResponseChoices struct {
} }
type ChatResponse struct { type ChatResponse struct {
Choices []ResponseChoices `json:"Choices,omitempty"` // 结果 Choices []ResponseChoices `json:"Choices,omitempty"` // 结果
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串 Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
Id string `json:"Id,omitempty"` // 会话 id Id string `json:"Id,omitempty"` // 会话 id
Usage Usage `json:"Usage,omitempty"` // token 数量 Usage Usage `json:"Usage,omitempty"` // token 数量
Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值 Error Error `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"Note,omitempty"` // 注释 Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"RequestId,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参 ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
} }
type ChatResponseP struct { type ChatResponseP struct {
Response ChatResponse `json:"Response,omitempty"` Response ChatResponse `json:"Response,omitempty"`
} }
type EmbeddingRequest struct {
InputList []string `json:"InputList"`
}
type EmbeddingData struct {
Embedding []float64 `json:"Embedding"`
Index int `json:"Index"`
Object string `json:"Object"`
}
type EmbeddingUsage struct {
PromptTokens int `json:"PromptTokens"`
TotalTokens int `json:"TotalTokens"`
}
type EmbeddingResponse struct {
Data []EmbeddingData `json:"Data"`
EmbeddingUsage EmbeddingUsage `json:"Usage,omitempty"`
RequestId string `json:"RequestId,omitempty"`
Error Error `json:"Error,omitempty"`
}
type EmbeddingResponseP struct {
Response EmbeddingResponse `json:"Response,omitempty"`
}

View File

@@ -13,12 +13,7 @@ import (
) )
var ModelList = []string{ var ModelList = []string{
"claude-3-haiku@20240307", "claude-3-haiku@20240307", "claude-3-opus@20240229", "claude-3-5-sonnet@20240620", "claude-3-sonnet@20240229",
"claude-3-sonnet@20240229",
"claude-3-opus@20240229",
"claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-v2@20241022",
"claude-3-5-haiku@20241022",
} }
const anthropicVersion = "vertex-2023-10-16" const anthropicVersion = "vertex-2023-10-16"

View File

@@ -11,8 +11,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`

View File

@@ -15,11 +15,7 @@ import (
) )
var ModelList = []string{ var ModelList = []string{
"gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"gemini-1.5-pro-001", "gemini-1.5-flash-001",
"gemini-1.5-pro-002", "gemini-1.5-flash-002",
"gemini-2.0-flash-exp",
"gemini-2.0-flash-thinking-exp", "gemini-2.0-flash-thinking-exp-01-21",
} }
type Adaptor struct { type Adaptor struct {

View File

@@ -1,5 +0,0 @@
package xai
var ModelList = []string{
"grok-beta",
}

View File

@@ -5,8 +5,6 @@ var ModelList = []string{
"SparkDesk-v1.1", "SparkDesk-v1.1",
"SparkDesk-v2.1", "SparkDesk-v2.1",
"SparkDesk-v3.1", "SparkDesk-v3.1",
"SparkDesk-v3.1-128K",
"SparkDesk-v3.5", "SparkDesk-v3.5",
"SparkDesk-v3.5-32K",
"SparkDesk-v4.0", "SparkDesk-v4.0",
} }

View File

@@ -272,9 +272,9 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl,
} }
func parseAPIVersionByModelName(modelName string) string { func parseAPIVersionByModelName(modelName string) string {
index := strings.IndexAny(modelName, "-") parts := strings.Split(modelName, "-")
if index != -1 { if len(parts) == 2 {
return modelName[index+1:] return parts[1]
} }
return "" return ""
} }
@@ -283,17 +283,13 @@ func parseAPIVersionByModelName(modelName string) string {
func apiVersion2domain(apiVersion string) string { func apiVersion2domain(apiVersion string) string {
switch apiVersion { switch apiVersion {
case "v1.1": case "v1.1":
return "lite" return "general"
case "v2.1": case "v2.1":
return "generalv2" return "generalv2"
case "v3.1": case "v3.1":
return "generalv3" return "generalv3"
case "v3.1-128K":
return "pro-128k"
case "v3.5": case "v3.5":
return "generalv3.5" return "generalv3.5"
case "v3.5-32K":
return "max-32k"
case "v4.0": case "v4.0":
return "4.0Ultra" return "4.0Ultra"
} }
@@ -301,17 +297,7 @@ func apiVersion2domain(apiVersion string) string {
} }
func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) {
var authUrl string
domain := apiVersion2domain(apiVersion) domain := apiVersion2domain(apiVersion)
switch apiVersion { authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
case "v3.1-128K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/pro-128k"), apiKey, apiSecret)
break
case "v3.5-32K":
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/chat/max-32k"), apiKey, apiSecret)
break
default:
authUrl = buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
}
return domain, authUrl return domain, authUrl
} }

View File

@@ -19,11 +19,11 @@ type ChatRequest struct {
} `json:"header"` } `json:"header"`
Parameter struct { Parameter struct {
Chat struct { Chat struct {
Domain string `json:"domain,omitempty"` Domain string `json:"domain,omitempty"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"` Auditing bool `json:"auditing,omitempty"`
} `json:"chat"` } `json:"chat"`
} `json:"parameter"` } `json:"parameter"`
Payload struct { Payload struct {

View File

@@ -4,13 +4,13 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"math"
"net/http" "net/http"
"strings" "strings"
) )
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request) baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, err return baiduEmbeddingRequest, err
default: default:
// TopP [0.0, 1.0] // TopP (0.0, 1.0)
request.TopP = helper.Float64PtrMax(request.TopP, 1) request.TopP = math.Min(0.99, request.TopP)
request.TopP = helper.Float64PtrMin(request.TopP, 0) request.TopP = math.Max(0.01, request.TopP)
// Temperature [0.0, 1.0] // Temperature (0.0, 1.0)
request.Temperature = helper.Float64PtrMax(request.Temperature, 1) request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = helper.Float64PtrMin(request.Temperature, 0) request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model) a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" { if a.APIVersion == "v4" {
return request, nil return request, nil

View File

@@ -12,8 +12,8 @@ type Message struct {
type Request struct { type Request struct {
Prompt []Message `json:"prompt"` Prompt []Message `json:"prompt"`
Temperature *float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"` RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"` Incremental bool `json:"incremental,omitempty"`
} }

View File

@@ -19,7 +19,6 @@ const (
DeepL DeepL
VertexAI VertexAI
Proxy Proxy
Replicate
Dummy // this one is only for count, do not add any channel after this Dummy // this one is only for count, do not add any channel after this
) )

View File

@@ -3,7 +3,6 @@ package billing
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
) )
@@ -32,17 +31,8 @@ func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQ
} }
// totalQuota is total quota consumed // totalQuota is total quota consumed
if totalQuota != 0 { if totalQuota != 0 {
logContent := fmt.Sprintf("倍率%.2f × %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, &model.Log{ model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
UserId: userId,
ChannelId: channelId,
PromptTokens: int(totalQuota),
CompletionTokens: 0,
ModelName: modelName,
TokenName: tokenName,
Quota: int(totalQuota),
Content: logContent,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota) model.UpdateChannelUsedQuota(channelId, totalQuota)
} }

View File

@@ -30,14 +30,6 @@ var ImageSizeRatios = map[string]map[string]float64{
"720x1280": 1, "720x1280": 1,
"1280x720": 1, "1280x720": 1,
}, },
"step-1x-medium": {
"256x256": 1,
"512x512": 1,
"768x768": 1,
"1024x1024": 1,
"1280x800": 1,
"800x1280": 1,
},
} }
var ImageGenerationAmounts = map[string][2]int{ var ImageGenerationAmounts = map[string][2]int{
@@ -47,7 +39,6 @@ var ImageGenerationAmounts = map[string][2]int{
"ali-stable-diffusion-v1.5": {1, 4}, // Ali "ali-stable-diffusion-v1.5": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali "wanx-v1": {1, 4}, // Ali
"cogview-3": {1, 1}, "cogview-3": {1, 1},
"step-1x-medium": {1, 1},
} }
var ImagePromptLengthLimitations = map[string]int{ var ImagePromptLengthLimitations = map[string]int{
@@ -57,7 +48,6 @@ var ImagePromptLengthLimitations = map[string]int{
"ali-stable-diffusion-v1.5": 4000, "ali-stable-diffusion-v1.5": 4000,
"wanx-v1": 4000, "wanx-v1": 4000,
"cogview-3": 833, "cogview-3": 833,
"step-1x-medium": 4000,
} }
var ImageOriginModelName = map[string]string{ var ImageOriginModelName = map[string]string{

View File

@@ -9,10 +9,9 @@ import (
) )
const ( const (
USD2RMB = 7 USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500 USD = 500 // $0.002 = 1 -> $1 = 500
MILLI_USD = 1.0 / 1000 * USD RMB = USD / USD2RMB
RMB = USD / USD2RMB
) )
// ModelRatio // ModelRatio
@@ -29,20 +28,15 @@ var ModelRatio = map[string]float64{
"gpt-4-32k": 30, "gpt-4-32k": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30, "gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens "gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens "gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-turbo": 5, // $0.01 / 1K tokens
"gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens
"gpt-4o": 2.5, // $0.005 / 1K tokens "gpt-4o": 2.5, // $0.005 / 1K tokens
"chatgpt-4o-latest": 2.5, // $0.005 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens
"gpt-4o-2024-05-13": 2.5, // $0.005 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-4o-2024-08-06": 1.25, // $0.0025 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-4o-2024-11-20": 1.25, // $0.0025 / 1K tokens
"gpt-4o-mini": 0.075, // $0.00015 / 1K tokens
"gpt-4o-mini-2024-07-18": 0.075, // $0.00015 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.25, // $0.0005 / 1K tokens
"gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
@@ -50,14 +44,8 @@ var ModelRatio = map[string]float64{
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"o1": 7.5, // $15.00 / 1M input tokens "davinci-002": 1, // $0.002 / 1K tokens
"o1-2024-12-17": 7.5, "babbage-002": 0.2, // $0.0004 / 1K tokens
"o1-preview": 7.5, // $15.00 / 1M input tokens
"o1-preview-2024-09-12": 7.5,
"o1-mini": 1.5, // $3.00 / 1M input tokens
"o1-mini-2024-09-12": 1.5,
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,
"text-curie-001": 1, "text-curie-001": 1,
@@ -87,10 +75,8 @@ var ModelRatio = map[string]float64{
"claude-2.0": 8.0 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD, "claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD, "claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-5-haiku-20241022": 1.0 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-4.0-8K": 0.120 * RMB, "ERNIE-4.0-8K": 0.120 * RMB,
@@ -110,16 +96,12 @@ var ModelRatio = map[string]float64{
"bge-large-en": 0.002 * RMB, "bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB, "tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing // https://ai.google.dev/pricing
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "PaLM-2": 1,
"gemini-1.0-pro": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.5-pro": 1, "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-1.5-pro-001": 1, "gemini-1.0-pro-vision-001": 1,
"gemini-1.5-flash": 1, "gemini-1.0-pro-001": 1,
"gemini-1.5-flash-001": 1, "gemini-1.5-pro": 1,
"gemini-2.0-flash-exp": 1,
"gemini-2.0-flash-thinking-exp": 1,
"gemini-2.0-flash-thinking-exp-01-21": 1,
"aqa": 1,
// https://open.bigmodel.cn/pricing // https://open.bigmodel.cn/pricing
"glm-4": 0.1 * RMB, "glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB, "glm-4v": 0.1 * RMB,
@@ -131,94 +113,27 @@ var ModelRatio = map[string]float64{
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"cogview-3": 0.25 * RMB, "cogview-3": 0.25 * RMB,
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 1.4286, // ¥0.02 / 1k tokens "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-turbo-latest": 1.4286, "qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-plus": 1.4286, "qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-plus-latest": 1.4286, "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"qwen-max-latest": 1.4286, "ali-stable-diffusion-xl": 8,
"qwen-max-longcontext": 1.4286, "ali-stable-diffusion-v1.5": 8,
"qwen-vl-max": 1.4286, "wanx-v1": 8,
"qwen-vl-max-latest": 1.4286, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"qwen-vl-plus": 1.4286, "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"qwen-vl-plus-latest": 1.4286, "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"qwen-vl-ocr": 1.4286, "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"qwen-vl-ocr-latest": 1.4286, "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"qwen-audio-turbo": 1.4286, "SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"qwen-math-plus": 1.4286, "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"qwen-math-plus-latest": 1.4286, "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"qwen-math-turbo": 1.4286, "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"qwen-math-turbo-latest": 1.4286, "semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"qwen-coder-plus": 1.4286, "hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"qwen-coder-plus-latest": 1.4286, "ChatStd": 0.01 * RMB,
"qwen-coder-turbo": 1.4286, "ChatPro": 0.1 * RMB,
"qwen-coder-turbo-latest": 1.4286,
"qwq-32b-preview": 1.4286,
"qwen2.5-72b-instruct": 1.4286,
"qwen2.5-32b-instruct": 1.4286,
"qwen2.5-14b-instruct": 1.4286,
"qwen2.5-7b-instruct": 1.4286,
"qwen2.5-3b-instruct": 1.4286,
"qwen2.5-1.5b-instruct": 1.4286,
"qwen2.5-0.5b-instruct": 1.4286,
"qwen2-72b-instruct": 1.4286,
"qwen2-57b-a14b-instruct": 1.4286,
"qwen2-7b-instruct": 1.4286,
"qwen2-1.5b-instruct": 1.4286,
"qwen2-0.5b-instruct": 1.4286,
"qwen1.5-110b-chat": 1.4286,
"qwen1.5-72b-chat": 1.4286,
"qwen1.5-32b-chat": 1.4286,
"qwen1.5-14b-chat": 1.4286,
"qwen1.5-7b-chat": 1.4286,
"qwen1.5-1.8b-chat": 1.4286,
"qwen1.5-0.5b-chat": 1.4286,
"qwen-72b-chat": 1.4286,
"qwen-14b-chat": 1.4286,
"qwen-7b-chat": 1.4286,
"qwen-1.8b-chat": 1.4286,
"qwen-1.8b-longcontext-chat": 1.4286,
"qwen2-vl-7b-instruct": 1.4286,
"qwen2-vl-2b-instruct": 1.4286,
"qwen-vl-v1": 1.4286,
"qwen-vl-chat-v1": 1.4286,
"qwen2-audio-instruct": 1.4286,
"qwen-audio-chat": 1.4286,
"qwen2.5-math-72b-instruct": 1.4286,
"qwen2.5-math-7b-instruct": 1.4286,
"qwen2.5-math-1.5b-instruct": 1.4286,
"qwen2-math-72b-instruct": 1.4286,
"qwen2-math-7b-instruct": 1.4286,
"qwen2-math-1.5b-instruct": 1.4286,
"qwen2.5-coder-32b-instruct": 1.4286,
"qwen2.5-coder-14b-instruct": 1.4286,
"qwen2.5-coder-7b-instruct": 1.4286,
"qwen2.5-coder-3b-instruct": 1.4286,
"qwen2.5-coder-1.5b-instruct": 1.4286,
"qwen2.5-coder-0.5b-instruct": 1.4286,
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"text-embedding-v3": 0.05,
"text-embedding-v2": 0.05,
"text-embedding-async-v2": 0.05,
"text-embedding-async-v1": 0.05,
"ali-stable-diffusion-xl": 8.00,
"ali-stable-diffusion-v1.5": 8.00,
"wanx-v1": 8.00,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.1-128K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v3.5-32K": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v4.0": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing // https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB, "moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB, "moonshot-v1-32k": 0.024 * RMB,
@@ -241,35 +156,20 @@ var ModelRatio = map[string]float64{
"mistral-large-latest": 8.0 / 1000 * USD, "mistral-large-latest": 8.0 / 1000 * USD,
"mistral-embed": 0.1 / 1000 * USD, "mistral-embed": 0.1 / 1000 * USD,
// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed // https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed
"gemma-7b-it": 0.07 / 1000000 * USD, "llama3-70b-8192": 0.59 / 1000 * USD,
"gemma2-9b-it": 0.20 / 1000000 * USD, "mixtral-8x7b-32768": 0.27 / 1000 * USD,
"llama-3.1-70b-versatile": 0.59 / 1000000 * USD, "llama3-8b-8192": 0.05 / 1000 * USD,
"llama-3.1-8b-instant": 0.05 / 1000000 * USD, "gemma-7b-it": 0.1 / 1000 * USD,
"llama-3.2-11b-text-preview": 0.05 / 1000000 * USD, "llama2-70b-4096": 0.64 / 1000 * USD,
"llama-3.2-11b-vision-preview": 0.05 / 1000000 * USD, "llama2-7b-2048": 0.1 / 1000 * USD,
"llama-3.2-1b-preview": 0.05 / 1000000 * USD,
"llama-3.2-3b-preview": 0.05 / 1000000 * USD,
"llama-3.2-90b-text-preview": 0.59 / 1000000 * USD,
"llama-guard-3-8b": 0.05 / 1000000 * USD,
"llama3-70b-8192": 0.59 / 1000000 * USD,
"llama3-8b-8192": 0.05 / 1000000 * USD,
"llama3-groq-70b-8192-tool-use-preview": 0.89 / 1000000 * USD,
"llama3-groq-8b-8192-tool-use-preview": 0.19 / 1000000 * USD,
"mixtral-8x7b-32768": 0.24 / 1000000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元 // https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000 * RMB, "yi-34b-chat-0205": 2.5 / 1000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000 * RMB, "yi-34b-chat-200k": 12.0 / 1000 * RMB,
"yi-vl-plus": 6.0 / 1000 * RMB, "yi-vl-plus": 6.0 / 1000 * RMB,
// https://platform.stepfun.com/docs/pricing/details // stepfun todo
"step-1-8k": 0.005 / 1000 * RMB, "step-1v-32k": 0.024 * RMB,
"step-1-32k": 0.015 / 1000 * RMB, "step-1-32k": 0.024 * RMB,
"step-1-128k": 0.040 / 1000 * RMB, "step-1-200k": 0.15 * RMB,
"step-1-256k": 0.095 / 1000 * RMB,
"step-1-flash": 0.001 / 1000 * RMB,
"step-2-16k": 0.038 / 1000 * RMB,
"step-1v-8k": 0.005 / 1000 * RMB,
"step-1v-32k": 0.015 / 1000 * RMB,
// aws llama3 https://aws.amazon.com/cn/bedrock/pricing/ // aws llama3 https://aws.amazon.com/cn/bedrock/pricing/
"llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens "llama3-8b-8192(33)": 0.0003 / 0.002, // $0.0003 / 1K tokens
"llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens "llama3-70b-8192(33)": 0.00265 / 0.002, // $0.00265 / 1K tokens
@@ -281,75 +181,22 @@ var ModelRatio = map[string]float64{
"command-r": 0.5 / 1000 * USD, "command-r": 0.5 / 1000 * USD,
"command-r-plus": 3.0 / 1000 * USD, "command-r-plus": 3.0 / 1000 * USD,
// https://platform.deepseek.com/api-docs/pricing/ // https://platform.deepseek.com/api-docs/pricing/
"deepseek-chat": 0.14 * MILLI_USD, "deepseek-chat": 1.0 / 1000 * RMB,
"deepseek-reasoner": 0.55 * MILLI_USD, "deepseek-coder": 1.0 / 1000 * RMB,
// https://www.deepl.com/pro?cta=header-prices // https://www.deepl.com/pro?cta=header-prices
"deepl-zh": 25.0 / 1000 * USD, "deepl-zh": 25.0 / 1000 * USD,
"deepl-en": 25.0 / 1000 * USD, "deepl-en": 25.0 / 1000 * USD,
"deepl-ja": 25.0 / 1000 * USD, "deepl-ja": 25.0 / 1000 * USD,
// https://console.x.ai/
"grok-beta": 5.0 / 1000 * USD,
// replicate charges based on the number of generated images
// https://replicate.com/pricing
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
"black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD,
"black-forest-labs/flux-canny-dev": 0.025 * USD,
"black-forest-labs/flux-canny-pro": 0.05 * USD,
"black-forest-labs/flux-depth-dev": 0.025 * USD,
"black-forest-labs/flux-depth-pro": 0.05 * USD,
"black-forest-labs/flux-dev": 0.025 * USD,
"black-forest-labs/flux-dev-lora": 0.032 * USD,
"black-forest-labs/flux-fill-dev": 0.04 * USD,
"black-forest-labs/flux-fill-pro": 0.05 * USD,
"black-forest-labs/flux-pro": 0.055 * USD,
"black-forest-labs/flux-redux-dev": 0.025 * USD,
"black-forest-labs/flux-redux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell": 0.003 * USD,
"black-forest-labs/flux-schnell-lora": 0.02 * USD,
"ideogram-ai/ideogram-v2": 0.08 * USD,
"ideogram-ai/ideogram-v2-turbo": 0.05 * USD,
"recraft-ai/recraft-v3": 0.04 * USD,
"recraft-ai/recraft-v3-svg": 0.08 * USD,
"stability-ai/stable-diffusion-3": 0.035 * USD,
"stability-ai/stable-diffusion-3.5-large": 0.065 * USD,
"stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD,
"stability-ai/stable-diffusion-3.5-medium": 0.035 * USD,
// replicate chat models
"ibm-granite/granite-20b-code-instruct-8k": 0.100 * USD,
"ibm-granite/granite-3.0-2b-instruct": 0.030 * USD,
"ibm-granite/granite-3.0-8b-instruct": 0.050 * USD,
"ibm-granite/granite-8b-code-instruct-128k": 0.050 * USD,
"meta/llama-2-13b": 0.100 * USD,
"meta/llama-2-13b-chat": 0.100 * USD,
"meta/llama-2-70b": 0.650 * USD,
"meta/llama-2-70b-chat": 0.650 * USD,
"meta/llama-2-7b": 0.050 * USD,
"meta/llama-2-7b-chat": 0.050 * USD,
"meta/meta-llama-3.1-405b-instruct": 9.500 * USD,
"meta/meta-llama-3-70b": 0.650 * USD,
"meta/meta-llama-3-70b-instruct": 0.650 * USD,
"meta/meta-llama-3-8b": 0.050 * USD,
"meta/meta-llama-3-8b-instruct": 0.050 * USD,
"mistralai/mistral-7b-instruct-v0.2": 0.050 * USD,
"mistralai/mistral-7b-v0.1": 0.050 * USD,
"mistralai/mixtral-8x7b-instruct-v0.1": 0.300 * USD,
} }
var CompletionRatio = map[string]float64{ var CompletionRatio = map[string]float64{
// aws llama3 // aws llama3
"llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-8b-8192(33)": 0.0006 / 0.0003,
"llama3-70b-8192(33)": 0.0035 / 0.00265, "llama3-70b-8192(33)": 0.0035 / 0.00265,
// whisper
"whisper-1": 0, // only count input tokens
// deepseek
"deepseek-chat": 0.28 / 0.14,
"deepseek-reasoner": 2.19 / 0.55,
} }
var ( var DefaultModelRatio map[string]float64
DefaultModelRatio map[string]float64 var DefaultCompletionRatio map[string]float64
DefaultCompletionRatio map[string]float64
)
func init() { func init() {
DefaultModelRatio = make(map[string]float64) DefaultModelRatio = make(map[string]float64)
@@ -461,25 +308,13 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 4.0 / 3.0 return 4.0 / 3.0
} }
if strings.HasPrefix(name, "gpt-4") { if strings.HasPrefix(name, "gpt-4") {
if strings.HasPrefix(name, "gpt-4o") {
if name == "gpt-4o-2024-05-13" {
return 3
}
return 4
}
if strings.HasPrefix(name, "gpt-4-turbo") || if strings.HasPrefix(name, "gpt-4-turbo") ||
strings.HasPrefix(name, "gpt-4o") ||
strings.HasSuffix(name, "preview") { strings.HasSuffix(name, "preview") {
return 3 return 3
} }
return 2 return 2
} }
// including o1, o1-preview, o1-mini
if strings.HasPrefix(name, "o1") {
return 4
}
if name == "chatgpt-4o-latest" {
return 3
}
if strings.HasPrefix(name, "claude-3") { if strings.HasPrefix(name, "claude-3") {
return 5 return 5
} }
@@ -495,7 +330,6 @@ func GetCompletionRatio(name string, channelType int) float64 {
if strings.HasPrefix(name, "deepseek-") { if strings.HasPrefix(name, "deepseek-") {
return 2 return 2
} }
switch name { switch name {
case "llama2-70b-4096": case "llama2-70b-4096":
return 0.8 / 0.64 return 0.8 / 0.64
@@ -509,37 +343,6 @@ func GetCompletionRatio(name string, channelType int) float64 {
return 3 return 3
case "command-r-plus": case "command-r-plus":
return 5 return 5
case "grok-beta":
return 3
// Replicate Models
// https://replicate.com/pricing
case "ibm-granite/granite-20b-code-instruct-8k":
return 5
case "ibm-granite/granite-3.0-2b-instruct":
return 8.333333333333334
case "ibm-granite/granite-3.0-8b-instruct",
"ibm-granite/granite-8b-code-instruct-128k":
return 5
case "meta/llama-2-13b",
"meta/llama-2-13b-chat",
"meta/llama-2-7b",
"meta/llama-2-7b-chat",
"meta/meta-llama-3-8b",
"meta/meta-llama-3-8b-instruct":
return 5
case "meta/llama-2-70b",
"meta/llama-2-70b-chat",
"meta/meta-llama-3-70b",
"meta/meta-llama-3-70b-instruct":
return 2.750 / 0.650 // ≈4.230769
case "meta/meta-llama-3.1-405b-instruct":
return 1
case "mistralai/mistral-7b-instruct-v0.2",
"mistralai/mistral-7b-v0.1":
return 5
case "mistralai/mixtral-8x7b-instruct-v0.1":
return 1.000 / 0.300 // ≈3.333333
} }
return 1 return 1
} }

View File

@@ -45,8 +45,5 @@ const (
Novita Novita
VertextAI VertextAI
Proxy Proxy
SiliconFlow
XAI
Replicate
Dummy Dummy
) )

View File

@@ -37,8 +37,6 @@ func ToAPIType(channelType int) int {
apiType = apitype.DeepL apiType = apitype.DeepL
case VertextAI: case VertextAI:
apiType = apitype.VertexAI apiType = apitype.VertexAI
case Replicate:
apiType = apitype.Replicate
case Proxy: case Proxy:
apiType = apitype.Proxy apiType = apitype.Proxy
} }

View File

@@ -45,9 +45,6 @@ var ChannelBaseURLs = []string{
"https://api.novita.ai/v3/openai", // 41 "https://api.novita.ai/v3/openai", // 41
"", // 42 "", // 42
"", // 43 "", // 43
"https://api.siliconflow.cn", // 44
"https://api.x.ai", // 45
"https://api.replicate.com/v1/models/", // 46
} }
func init() { func init() {

View File

@@ -1,6 +1,5 @@
package role package role
const ( const (
System = "system"
Assistant = "assistant" Assistant = "assistant"
) )

View File

@@ -110,9 +110,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}() }()
// map model name // map model name
modelMapping := c.GetStringMapString(ctxkey.ModelMapping) modelMapping := c.GetString(ctxkey.ModelMapping)
if modelMapping != nil && modelMapping[audioModel] != "" { if modelMapping != "" {
audioModel = modelMapping[audioModel] modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[audioModel] != "" {
audioModel = modelMap[audioModel]
}
} }
baseURL := channeltype.ChannelBaseURLs[channelType] baseURL := channeltype.ChannelBaseURLs[channelType]

View File

@@ -8,11 +8,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/constant/role"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@@ -94,7 +90,7 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil return preConsumedQuota, nil
} }
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64, systemPromptReset bool) { func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
if usage == nil { if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected") logger.Error(ctx, "usage is nil, which is unexpected")
return return
@@ -122,20 +118,8 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
if err != nil { if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error()) logger.Error(ctx, "error update user quota cache: "+err.Error())
} }
logContent := fmt.Sprintf("倍率%.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
model.RecordConsumeLog(ctx, &model.Log{ model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, promptTokens, completionTokens, textRequest.Model, meta.TokenName, quota, logContent)
UserId: meta.UserId,
ChannelId: meta.ChannelId,
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
ModelName: textRequest.Model,
TokenName: meta.TokenName,
Quota: int(quota),
Content: logContent,
IsStream: meta.IsStream,
ElapsedTime: helper.CalcElapsedTime(meta.StartTime),
SystemPromptReset: systemPromptReset,
})
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota)
} }
@@ -158,41 +142,15 @@ func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
} }
return true return true
} }
if resp.StatusCode != http.StatusOK && if resp.StatusCode != http.StatusOK {
// replicate return 201 to create a task
resp.StatusCode != http.StatusCreated {
return true return true
} }
if meta.ChannelType == channeltype.DeepL { if meta.ChannelType == channeltype.DeepL {
// skip stream check for deepl // skip stream check for deepl
return false return false
} }
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") {
if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
// Even if stream mode is enabled, replicate will first return a task info in JSON format,
// requiring the client to request the stream endpoint in the task info
meta.ChannelType != channeltype.Replicate {
return true return true
} }
return false return false
} }
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
if prompt == "" {
return false
}
if len(request.Messages) == 0 {
return false
}
if request.Messages[0].Role == role.System {
request.Messages[0].Content = prompt
logger.Infof(ctx, "rewrite system prompt")
return true
}
request.Messages = append([]relaymodel.Message{{
Role: role.System,
Content: prompt,
}}, request.Messages...)
logger.Infof(ctx, "add system prompt")
return true
}

View File

@@ -10,7 +10,6 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
@@ -23,7 +22,7 @@ import (
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func getImageRequest(c *gin.Context, _ int) (*relaymodel.ImageRequest, error) { func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{} imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest) err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil { if err != nil {
@@ -66,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
return 1 return 1
} }
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode { func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length // check prompt length
if imageRequest.Prompt == "" { if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -151,12 +150,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
adaptor.Init(meta) adaptor.Init(meta)
// these adaptors need to convert the request
switch meta.ChannelType { switch meta.ChannelType {
case channeltype.Zhipu, case channeltype.Ali:
channeltype.Ali, fallthrough
channeltype.Replicate, case channeltype.Baidu:
channeltype.Baidu: fallthrough
case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest) finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
@@ -173,14 +172,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
ratio := modelRatio * groupRatio ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
var quota int64 quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
switch meta.ChannelType {
case channeltype.Replicate:
// replicate always return 1 image
quota = int64(ratio * imageCostRatio * 1000)
default:
quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
}
if userQuota-quota < 0 { if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@@ -194,9 +186,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
defer func(ctx context.Context) { defer func(ctx context.Context) {
if resp != nil && if resp != nil && resp.StatusCode != http.StatusOK {
resp.StatusCode != http.StatusCreated && // replicate returns 201
resp.StatusCode != http.StatusOK {
return return
} }
@@ -210,17 +200,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
} }
if quota != 0 { if quota != 0 {
tokenName := c.GetString(ctxkey.TokenName) tokenName := c.GetString(ctxkey.TokenName)
logContent := fmt.Sprintf("倍率%.2f × %.2f", modelRatio, groupRatio) logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, &model.Log{ model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent)
UserId: meta.UserId,
ChannelId: meta.ChannelId,
PromptTokens: 0,
CompletionTokens: 0,
ModelName: imageRequest.Model,
TokenName: tokenName,
Quota: int(quota),
Content: logContent,
})
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
model.UpdateChannelUsedQuota(channelId, quota) model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -8,11 +8,8 @@ import (
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing" "github.com/songquanpeng/one-api/relay/billing"
@@ -34,11 +31,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream meta.IsStream = textRequest.Stream
// map model name // map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping) textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model meta.ActualModelName = textRequest.Model
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
// get model ratio & group ratio // get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType) modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := billingratio.GetGroupRatio(meta.Group)
@@ -59,9 +55,30 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
adaptor.Init(meta) adaptor.Init(meta)
// get request body // get request body
requestBody, err := getRequestBody(c, meta, textRequest, adaptor) var requestBody io.Reader
if err != nil { if meta.APIType == apitype.OpenAI {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError) // no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
} }
// do request // do request
@@ -83,29 +100,6 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return respErr return respErr
} }
// post-consume quota // post-consume quota
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio, systemPromptReset) go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil return nil
} }
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}

Some files were not shown because too many files have changed in this diff Show More