mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-10-27 03:43:43 +08:00
Compare commits
49 Commits
v0.5.6-alp
...
v0.5.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
495fc628e4 | ||
|
|
76f9288c34 | ||
|
|
915d13fdd4 | ||
|
|
969f539777 | ||
|
|
54e5f8ecd2 | ||
|
|
34d517cfa2 | ||
|
|
ddcaf95f5f | ||
|
|
1d15157f7d | ||
|
|
de7b9710a5 | ||
|
|
58bb3ab6f6 | ||
|
|
d306cb5229 | ||
|
|
6c5307d0c4 | ||
|
|
7c4505bdfc | ||
|
|
9d43ec57d8 | ||
|
|
e5311892d1 | ||
|
|
bc7c9105f4 | ||
|
|
3fe76c8af7 | ||
|
|
c70c614018 | ||
|
|
0d87de697c | ||
|
|
aec343dc38 | ||
|
|
89d458b9cf | ||
|
|
63fafba112 | ||
|
|
a398f35968 | ||
|
|
57aa637c77 | ||
|
|
3b483639a4 | ||
|
|
22980b4c44 | ||
|
|
64cdb7eafb | ||
|
|
824444244b | ||
|
|
fbe9985f57 | ||
|
|
a27a5bcc06 | ||
|
|
e28d4b1741 | ||
|
|
f073592d39 | ||
|
|
fa41ca9805 | ||
|
|
e338de45b6 | ||
|
|
114587b46f | ||
|
|
b4b4acc288 | ||
|
|
d663de3e3a | ||
|
|
a85ecace2e | ||
|
|
fbdea91ea1 | ||
|
|
8d34b7a77e | ||
|
|
cbd62011b8 | ||
|
|
4701897e2e | ||
|
|
0f6c132a80 | ||
|
|
3cac45dc85 | ||
|
|
47c08c72ce | ||
|
|
53b2cace0b | ||
|
|
f0fc991b44 | ||
|
|
594f06e7b0 | ||
|
|
197d1d7a9d |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,4 +5,5 @@ upload
|
|||||||
*.db
|
*.db
|
||||||
build
|
build
|
||||||
*.db-journal
|
*.db-journal
|
||||||
logs
|
logs
|
||||||
|
data
|
||||||
@@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
|
|||||||
|
|
||||||
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
|
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
|
||||||
|
|
||||||
|
[](https://zeabur.com/templates/7Q0KO3)
|
||||||
|
|
||||||
1. First, fork the code.
|
1. First, fork the code.
|
||||||
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
|
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
|
||||||
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
|
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
|
||||||
|
|||||||
@@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for
|
|||||||
|
|
||||||
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
|
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
|
||||||
|
|
||||||
|
[](https://zeabur.com/templates/7Q0KO3)
|
||||||
|
|
||||||
1. まず、コードをフォークする。
|
1. まず、コードをフォークする。
|
||||||
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
|
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
|
||||||
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
|
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
|
||||||
|
|||||||
70
README.md
70
README.md
@@ -59,6 +59,9 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
> **Warning**
|
> **Warning**
|
||||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
||||||
|
|
||||||
|
> **Warning**
|
||||||
|
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`!
|
||||||
|
|
||||||
## 功能
|
## 功能
|
||||||
1. 支持多种大模型:
|
1. 支持多种大模型:
|
||||||
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||||
@@ -69,9 +72,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||||
+ [x] [360 智脑](https://ai.360.cn)
|
+ [x] [360 智脑](https://ai.360.cn)
|
||||||
|
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
||||||
2. 支持配置镜像以及众多第三方代理服务:
|
2. 支持配置镜像以及众多第三方代理服务:
|
||||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
+ [x] [CloseAI](https://referer.shadowai.xyz/r/2412)
|
||||||
+ [x] [API2D](https://api2d.com/r/197971)
|
+ [x] [API2D](https://api2d.com/r/197971)
|
||||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||||
@@ -88,26 +92,33 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
|||||||
12. 支持**用户邀请奖励**。
|
12. 支持**用户邀请奖励**。
|
||||||
13. 支持以美元为单位显示额度。
|
13. 支持以美元为单位显示额度。
|
||||||
14. 支持发布公告,设置充值链接,设置新用户初始额度。
|
14. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||||
15. 支持模型映射,重定向用户的请求模型。
|
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。
|
||||||
16. 支持失败自动重试。
|
16. 支持失败自动重试。
|
||||||
17. 支持绘图接口。
|
17. 支持绘图接口。
|
||||||
18. 支持丰富的**自定义**设置,
|
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
|
||||||
|
19. 支持丰富的**自定义**设置,
|
||||||
1. 支持自定义系统名称,logo 以及页脚。
|
1. 支持自定义系统名称,logo 以及页脚。
|
||||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||||
19. 支持通过系统访问令牌访问管理 API。
|
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
|
||||||
20. 支持 Cloudflare Turnstile 用户校验。
|
21. 支持 Cloudflare Turnstile 用户校验。
|
||||||
21. 支持用户管理,支持**多种用户登录注册方式**:
|
22. 支持用户管理,支持**多种用户登录注册方式**:
|
||||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||||
|
|
||||||
## 部署
|
## 部署
|
||||||
### 基于 Docker 进行部署
|
### 基于 Docker 进行部署
|
||||||
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
|
```shell
|
||||||
|
# 使用 SQLite 的部署命令:
|
||||||
|
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||||
|
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
|
||||||
|
# 例如:
|
||||||
|
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||||
|
```
|
||||||
|
|
||||||
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
|
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
|
||||||
|
|
||||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||||
|
|
||||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
||||||
|
|
||||||
@@ -149,6 +160,19 @@ sudo service nginx restart
|
|||||||
|
|
||||||
初始账号用户名为 `root`,密码为 `123456`。
|
初始账号用户名为 `root`,密码为 `123456`。
|
||||||
|
|
||||||
|
|
||||||
|
### 基于 Docker Compose 进行部署
|
||||||
|
|
||||||
|
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
|
||||||
|
|
||||||
|
```shell
|
||||||
|
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
|
||||||
|
docker-compose up -d
|
||||||
|
|
||||||
|
# 查看部署状态
|
||||||
|
docker-compose ps
|
||||||
|
```
|
||||||
|
|
||||||
### 手动部署
|
### 手动部署
|
||||||
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
|
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
|
||||||
```shell
|
```shell
|
||||||
@@ -236,7 +260,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
<summary><strong>部署到 Zeabur</strong></summary>
|
<summary><strong>部署到 Zeabur</strong></summary>
|
||||||
<div>
|
<div>
|
||||||
|
|
||||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
|
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
|
||||||
|
|
||||||
|
[](https://zeabur.com/templates/7Q0KO3)
|
||||||
|
|
||||||
1. 首先 fork 一份代码。
|
1. 首先 fork 一份代码。
|
||||||
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
|
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
|
||||||
@@ -251,6 +277,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
|||||||
</div>
|
</div>
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><strong>部署到 Render</strong></summary>
|
||||||
|
<div>
|
||||||
|
|
||||||
|
> Render 提供免费额度,绑卡后可以进一步提升额度
|
||||||
|
|
||||||
|
Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com
|
||||||
|
|
||||||
|
</div>
|
||||||
|
</details>
|
||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
系统本身开箱即用。
|
系统本身开箱即用。
|
||||||
|
|
||||||
@@ -278,10 +315,11 @@ OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
|
|||||||
```mermaid
|
```mermaid
|
||||||
graph LR
|
graph LR
|
||||||
A(用户)
|
A(用户)
|
||||||
A --->|请求| B(One API)
|
A --->|使用 One API 分发的 key 进行请求| B(One API)
|
||||||
B -->|中继请求| C(OpenAI)
|
B -->|中继请求| C(OpenAI)
|
||||||
B -->|中继请求| D(Azure)
|
B -->|中继请求| D(Azure)
|
||||||
B -->|中继请求| E(其他下游渠道)
|
B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
|
||||||
|
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
|
||||||
```
|
```
|
||||||
|
|
||||||
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
|
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
|
||||||
@@ -329,6 +367,10 @@ graph LR
|
|||||||
13. 请求频率限制:
|
13. 请求频率限制:
|
||||||
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||||
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||||
|
14. 编码器缓存设置:
|
||||||
|
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
||||||
|
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
||||||
|
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
||||||
|
|
||||||
### 命令行参数
|
### 命令行参数
|
||||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||||
@@ -368,6 +410,12 @@ https://openai.justsong.cn
|
|||||||
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
||||||
6. 报错:`当前分组负载已饱和,请稍后再试`
|
6. 报错:`当前分组负载已饱和,请稍后再试`
|
||||||
+ 上游通道 429 了。
|
+ 上游通道 429 了。
|
||||||
|
7. 升级之后我的数据会丢失吗?
|
||||||
|
+ 如果使用 MySQL,不会。
|
||||||
|
+ 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
|
||||||
|
8. 升级之前数据库需要做变更吗?
|
||||||
|
+ 一般情况下不需要,系统将在初始化的时候自动调整。
|
||||||
|
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
|
||||||
|
|
||||||
## 相关项目
|
## 相关项目
|
||||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
||||||
|
|||||||
@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
|||||||
var DisplayInCurrencyEnabled = true
|
var DisplayInCurrencyEnabled = true
|
||||||
var DisplayTokenStatEnabled = true
|
var DisplayTokenStatEnabled = true
|
||||||
|
|
||||||
var UsingSQLite = false
|
|
||||||
|
|
||||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||||
|
|
||||||
var SessionSecret = uuid.New().String()
|
var SessionSecret = uuid.New().String()
|
||||||
var SQLitePath = "one-api.db"
|
|
||||||
|
|
||||||
var OptionMap map[string]string
|
var OptionMap map[string]string
|
||||||
var OptionMapRWMutex sync.RWMutex
|
var OptionMapRWMutex sync.RWMutex
|
||||||
@@ -98,6 +95,8 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
|||||||
var BatchUpdateEnabled = false
|
var BatchUpdateEnabled = false
|
||||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||||
|
|
||||||
|
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RequestIdKey = "X-Oneapi-Request-Id"
|
RequestIdKey = "X-Oneapi-Request-Id"
|
||||||
)
|
)
|
||||||
@@ -156,9 +155,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ChannelStatusUnknown = 0
|
ChannelStatusUnknown = 0
|
||||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||||
ChannelStatusDisabled = 2 // also don't use 0
|
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||||
|
ChannelStatusAutoDisabled = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -185,30 +185,32 @@ const (
|
|||||||
ChannelTypeOpenRouter = 20
|
ChannelTypeOpenRouter = 20
|
||||||
ChannelTypeAIProxyLibrary = 21
|
ChannelTypeAIProxyLibrary = 21
|
||||||
ChannelTypeFastGPT = 22
|
ChannelTypeFastGPT = 22
|
||||||
|
ChannelTypeTencent = 23
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
"", // 0
|
"", // 0
|
||||||
"https://api.openai.com", // 1
|
"https://api.openai.com", // 1
|
||||||
"https://oa.api2d.net", // 2
|
"https://oa.api2d.net", // 2
|
||||||
"", // 3
|
"", // 3
|
||||||
"https://api.closeai-proxy.xyz", // 4
|
"https://api.closeai-proxy.xyz", // 4
|
||||||
"https://api.openai-sb.com", // 5
|
"https://api.openai-sb.com", // 5
|
||||||
"https://api.openaimax.com", // 6
|
"https://api.openaimax.com", // 6
|
||||||
"https://api.ohmygpt.com", // 7
|
"https://api.ohmygpt.com", // 7
|
||||||
"", // 8
|
"", // 8
|
||||||
"https://api.caipacity.com", // 9
|
"https://api.caipacity.com", // 9
|
||||||
"https://api.aiproxy.io", // 10
|
"https://api.aiproxy.io", // 10
|
||||||
"", // 11
|
"", // 11
|
||||||
"https://api.api2gpt.com", // 12
|
"https://api.api2gpt.com", // 12
|
||||||
"https://api.aigc2d.com", // 13
|
"https://api.aigc2d.com", // 13
|
||||||
"https://api.anthropic.com", // 14
|
"https://api.anthropic.com", // 14
|
||||||
"https://aip.baidubce.com", // 15
|
"https://aip.baidubce.com", // 15
|
||||||
"https://open.bigmodel.cn", // 16
|
"https://open.bigmodel.cn", // 16
|
||||||
"https://dashscope.aliyuncs.com", // 17
|
"https://dashscope.aliyuncs.com", // 17
|
||||||
"", // 18
|
"", // 18
|
||||||
"https://ai.360.cn", // 19
|
"https://ai.360.cn", // 19
|
||||||
"https://openrouter.ai/api", // 20
|
"https://openrouter.ai/api", // 20
|
||||||
"https://api.aiproxy.io", // 21
|
"https://api.aiproxy.io", // 21
|
||||||
"https://fastgpt.run/api/openapi", // 22
|
"https://fastgpt.run/api/openapi", // 22
|
||||||
|
"https://hunyuan.cloud.tencent.com", //23
|
||||||
}
|
}
|
||||||
|
|||||||
6
common/database.go
Normal file
6
common/database.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
var UsingSQLite = false
|
||||||
|
var UsingPostgreSQL = false
|
||||||
|
|
||||||
|
var SQLitePath = "one-api.db"
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||||
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = json.Unmarshal(requestBody, &v)
|
contentType := c.Request.Header.Get("Content-Type")
|
||||||
|
if strings.HasPrefix(contentType, "application/json") {
|
||||||
|
err = json.Unmarshal(requestBody, &v)
|
||||||
|
} else {
|
||||||
|
// skip for now
|
||||||
|
// TODO: someday non json request have variant model, we will need to implementation this
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,8 +3,32 @@ package common
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var DalleSizeRatios = map[string]map[string]float64{
|
||||||
|
"dall-e-2": {
|
||||||
|
"256x256": 1,
|
||||||
|
"512x512": 1.125,
|
||||||
|
"1024x1024": 1.25,
|
||||||
|
},
|
||||||
|
"dall-e-3": {
|
||||||
|
"1024x1024": 1,
|
||||||
|
"1024x1792": 2,
|
||||||
|
"1792x1024": 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var DalleGenerationImageAmounts = map[string][2]int{
|
||||||
|
"dall-e-2": {1, 10},
|
||||||
|
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
|
||||||
|
}
|
||||||
|
|
||||||
|
var DalleImagePromptLengthLimitations = map[string]int{
|
||||||
|
"dall-e-2": 1000,
|
||||||
|
"dall-e-3": 4000,
|
||||||
|
}
|
||||||
|
|
||||||
// ModelRatio
|
// ModelRatio
|
||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||||
@@ -19,12 +43,15 @@ var ModelRatio = map[string]float64{
|
|||||||
"gpt-4-32k": 30,
|
"gpt-4-32k": 30,
|
||||||
"gpt-4-32k-0314": 30,
|
"gpt-4-32k-0314": 30,
|
||||||
"gpt-4-32k-0613": 30,
|
"gpt-4-32k-0613": 30,
|
||||||
|
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||||
|
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
||||||
"gpt-3.5-turbo-0301": 0.75,
|
"gpt-3.5-turbo-0301": 0.75,
|
||||||
"gpt-3.5-turbo-0613": 0.75,
|
"gpt-3.5-turbo-0613": 0.75,
|
||||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||||
|
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||||
"text-ada-001": 0.2,
|
"text-ada-001": 0.2,
|
||||||
"text-babbage-001": 0.25,
|
"text-babbage-001": 0.25,
|
||||||
"text-curie-001": 1,
|
"text-curie-001": 1,
|
||||||
@@ -32,7 +59,11 @@ var ModelRatio = map[string]float64{
|
|||||||
"text-davinci-003": 10,
|
"text-davinci-003": 10,
|
||||||
"text-davinci-edit-001": 10,
|
"text-davinci-edit-001": 10,
|
||||||
"code-davinci-edit-001": 10,
|
"code-davinci-edit-001": 10,
|
||||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||||
|
"tts-1": 7.5, // $0.015 / 1K characters
|
||||||
|
"tts-1-1106": 7.5,
|
||||||
|
"tts-1-hd": 15, // $0.030 / 1K characters
|
||||||
|
"tts-1-hd-1106": 15,
|
||||||
"davinci": 10,
|
"davinci": 10,
|
||||||
"curie": 10,
|
"curie": 10,
|
||||||
"babbage": 10,
|
"babbage": 10,
|
||||||
@@ -41,13 +72,16 @@ var ModelRatio = map[string]float64{
|
|||||||
"text-search-ada-doc-001": 10,
|
"text-search-ada-doc-001": 10,
|
||||||
"text-moderation-stable": 0.1,
|
"text-moderation-stable": 0.1,
|
||||||
"text-moderation-latest": 0.1,
|
"text-moderation-latest": 0.1,
|
||||||
"dall-e": 8,
|
"dall-e-2": 8, // $0.016 - $0.020 / image
|
||||||
|
"dall-e-3": 20, // $0.040 - $0.120 / image
|
||||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||||
|
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"PaLM-2": 1,
|
"PaLM-2": 1,
|
||||||
|
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||||
@@ -59,7 +93,7 @@ var ModelRatio = map[string]float64{
|
|||||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||||
"360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens
|
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
func ModelRatio2JSONString() string {
|
||||||
@@ -86,9 +120,24 @@ func GetModelRatio(name string) float64 {
|
|||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
if strings.HasPrefix(name, "gpt-3.5") {
|
if strings.HasPrefix(name, "gpt-3.5") {
|
||||||
|
if strings.HasSuffix(name, "1106") {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
|
||||||
|
// TODO: clear this after 2023-12-11
|
||||||
|
now := time.Now()
|
||||||
|
// https://platform.openai.com/docs/models/continuous-model-upgrades
|
||||||
|
// if after 2023-12-11, use 2
|
||||||
|
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
return 1.333333
|
return 1.333333
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "gpt-4") {
|
if strings.HasPrefix(name, "gpt-4") {
|
||||||
|
if strings.HasSuffix(name, "preview") {
|
||||||
|
return 3
|
||||||
|
}
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(name, "claude-instant-1") {
|
if strings.HasPrefix(name, "claude-instant-1") {
|
||||||
|
|||||||
@@ -199,3 +199,11 @@ func GetOrDefault(env string, defaultValue int) int {
|
|||||||
func MessageWithRequestId(message string, id string) string {
|
func MessageWithRequestId(message string, id string) string {
|
||||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func String2Int(str string) int {
|
||||||
|
num, err := strconv.Atoi(str)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return num
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,13 +5,15 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||||
@@ -42,14 +44,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
}
|
}
|
||||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||||
if channel.Type == common.ChannelTypeAzure {
|
if channel.Type == common.ChannelTypeAzure {
|
||||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||||
} else {
|
} else {
|
||||||
if channel.GetBaseURL() != "" {
|
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
||||||
requestURL = channel.GetBaseURL()
|
requestURL = baseURL
|
||||||
}
|
}
|
||||||
requestURL += "/v1/chat/completions"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
||||||
|
}
|
||||||
jsonData, err := json.Marshal(request)
|
jsonData, err := json.Marshal(request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
@@ -70,10 +72,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
var response TextResponse
|
var response TextResponse
|
||||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
body, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
err = json.Unmarshal(body, &response)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
||||||
|
}
|
||||||
if response.Usage.CompletionTokens == 0 {
|
if response.Usage.CompletionTokens == 0 {
|
||||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||||
}
|
}
|
||||||
@@ -141,7 +147,7 @@ func disableChannel(channelId int, channelName string, reason string) {
|
|||||||
if common.RootUserEmail == "" {
|
if common.RootUserEmail == "" {
|
||||||
common.RootUserEmail = model.GetRootUserEmail()
|
common.RootUserEmail = model.GetRootUserEmail()
|
||||||
}
|
}
|
||||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
|
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||||
|
|||||||
@@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteDisabledChannel(c *gin.Context) {
|
||||||
|
rows, err := model.DeleteDisabledChannel()
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"message": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"message": "",
|
||||||
|
"data": rows,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func UpdateChannel(c *gin.Context) {
|
func UpdateChannel(c *gin.Context) {
|
||||||
channel := model.Channel{}
|
channel := model.Channel{}
|
||||||
err := c.ShouldBindJSON(&channel)
|
err := c.ShouldBindJSON(&channel)
|
||||||
|
|||||||
@@ -55,12 +55,21 @@ func init() {
|
|||||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||||
openAIModels = []OpenAIModels{
|
openAIModels = []OpenAIModels{
|
||||||
{
|
{
|
||||||
Id: "dall-e",
|
Id: "dall-e-2",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "openai",
|
OwnedBy: "openai",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "dall-e",
|
Root: "dall-e-2",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "dall-e-3",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "dall-e-3",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -72,6 +81,42 @@ func init() {
|
|||||||
Root: "whisper-1",
|
Root: "whisper-1",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "tts-1",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "tts-1",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "tts-1-1106",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "tts-1-1106",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "tts-1-hd",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "tts-1-hd",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "tts-1-hd-1106",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "tts-1-hd-1106",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "gpt-3.5-turbo",
|
Id: "gpt-3.5-turbo",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -117,6 +162,15 @@ func init() {
|
|||||||
Root: "gpt-3.5-turbo-16k-0613",
|
Root: "gpt-3.5-turbo-16k-0613",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "gpt-3.5-turbo-1106",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1699593571,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "gpt-3.5-turbo-1106",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "gpt-3.5-turbo-instruct",
|
Id: "gpt-3.5-turbo-instruct",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -180,6 +234,24 @@ func init() {
|
|||||||
Root: "gpt-4-32k-0613",
|
Root: "gpt-4-32k-0613",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "gpt-4-1106-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1699593571,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "gpt-4-1106-preview",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Id: "gpt-4-vision-preview",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1699593571,
|
||||||
|
OwnedBy: "openai",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "gpt-4-vision-preview",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "text-embedding-ada-002",
|
Id: "text-embedding-ada-002",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -274,7 +346,7 @@ func init() {
|
|||||||
Id: "claude-instant-1",
|
Id: "claude-instant-1",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "anturopic",
|
OwnedBy: "anthropic",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "claude-instant-1",
|
Root: "claude-instant-1",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
@@ -283,7 +355,7 @@ func init() {
|
|||||||
Id: "claude-2",
|
Id: "claude-2",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "anturopic",
|
OwnedBy: "anthropic",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "claude-2",
|
Root: "claude-2",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
@@ -306,6 +378,15 @@ func init() {
|
|||||||
Root: "ERNIE-Bot-turbo",
|
Root: "ERNIE-Bot-turbo",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "ERNIE-Bot-4",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "baidu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "ERNIE-Bot-4",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "Embedding-V1",
|
Id: "Embedding-V1",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -324,6 +405,15 @@ func init() {
|
|||||||
Root: "PaLM-2",
|
Root: "PaLM-2",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Id: "chatglm_turbo",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1677649963,
|
||||||
|
OwnedBy: "zhipu",
|
||||||
|
Permission: permission,
|
||||||
|
Root: "chatglm_turbo",
|
||||||
|
Parent: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Id: "chatglm_pro",
|
Id: "chatglm_pro",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -424,12 +514,12 @@ func init() {
|
|||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: "360GPT_S2_V9.4",
|
Id: "hunyuan",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
Created: 1677649963,
|
Created: 1677649963,
|
||||||
OwnedBy: "360",
|
OwnedBy: "tencent",
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Root: "360GPT_S2_V9.4",
|
Root: "hunyuan",
|
||||||
Parent: nil,
|
Parent: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func UpdateOption(c *gin.Context) {
|
|||||||
if option.Value == "true" && common.GitHubClientId == "" {
|
if option.Value == "true" && common.GitHubClientId == "" {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!",
|
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
|
|||||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||||
query := ""
|
query := ""
|
||||||
if len(request.Messages) != 0 {
|
if len(request.Messages) != 0 {
|
||||||
query = request.Messages[len(request.Messages)-1].Content
|
query = request.Messages[len(request.Messages)-1].StringContent()
|
||||||
}
|
}
|
||||||
return &AIProxyLibraryRequest{
|
return &AIProxyLibraryRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
|
|||||||
@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
|||||||
message := request.Messages[i]
|
message := request.Messages[i]
|
||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
messages = append(messages, AliMessage{
|
messages = append(messages, AliMessage{
|
||||||
User: message.Content,
|
User: message.StringContent(),
|
||||||
Bot: "Okay",
|
Bot: "Okay",
|
||||||
})
|
})
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
if i == len(request.Messages)-1 {
|
if i == len(request.Messages)-1 {
|
||||||
prompt = message.Content
|
prompt = message.StringContent()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
messages = append(messages, AliMessage{
|
messages = append(messages, AliMessage{
|
||||||
User: message.Content,
|
User: message.StringContent(),
|
||||||
Bot: request.Messages[i+1].Content,
|
Bot: request.Messages[i+1].StringContent(),
|
||||||
})
|
})
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"errors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
@@ -21,6 +20,22 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
userId := c.GetInt("id")
|
userId := c.GetInt("id")
|
||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
|
tokenName := c.GetString("token_name")
|
||||||
|
|
||||||
|
var ttsRequest TextToSpeechRequest
|
||||||
|
if relayMode == RelayModeAudioSpeech {
|
||||||
|
// Read JSON
|
||||||
|
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
||||||
|
// Check if JSON is valid
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "invalid_json", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
audioModel = ttsRequest.Model
|
||||||
|
// Check if text is too long 4096
|
||||||
|
if len(ttsRequest.Input) > 4096 {
|
||||||
|
return errorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
modelRatio := common.GetModelRatio(audioModel)
|
modelRatio := common.GetModelRatio(audioModel)
|
||||||
@@ -31,19 +46,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
|
||||||
if err != nil {
|
quota := 0
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
// Check if user quota is enough
|
||||||
}
|
if relayMode == RelayModeAudioSpeech {
|
||||||
if userQuota > 100*preConsumedQuota {
|
quota = int(float64(len(ttsRequest.Input)) * modelRatio * groupRatio)
|
||||||
// in this case, we do not pre-consume quota
|
if quota > userQuota {
|
||||||
// because the user has enough quota
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
preConsumedQuota = 0
|
}
|
||||||
}
|
} else {
|
||||||
if preConsumedQuota > 0 {
|
if userQuota-preConsumedQuota < 0 {
|
||||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
if userQuota > 100*preConsumedQuota {
|
||||||
|
// in this case, we do not pre-consume quota
|
||||||
|
// because the user has enough quota
|
||||||
|
preConsumedQuota = 0
|
||||||
|
}
|
||||||
|
if preConsumedQuota > 0 {
|
||||||
|
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,12 +90,11 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
|
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
requestBody := c.Request.Body
|
requestBody := c.Request.Body
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
@@ -91,47 +118,32 @@ func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
var audioResponse AudioResponse
|
|
||||||
|
|
||||||
defer func(ctx context.Context) {
|
if relayMode == RelayModeAudioSpeech {
|
||||||
go func() {
|
defer func(ctx context.Context) {
|
||||||
quota := countTokenText(audioResponse.Text, audioModel)
|
go postConsumeQuota(ctx, tokenId, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||||
|
}(c.Request.Context())
|
||||||
|
} else {
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
var whisperResponse WhisperResponse
|
||||||
|
err = json.Unmarshal(responseBody, &whisperResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
defer func(ctx context.Context) {
|
||||||
|
quota := countTokenText(whisperResponse.Text, audioModel)
|
||||||
quotaDelta := quota - preConsumedQuota
|
quotaDelta := quota - preConsumedQuota
|
||||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
go postConsumeQuota(ctx, tokenId, quotaDelta, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||||
if err != nil {
|
}(c.Request.Context())
|
||||||
common.SysError("error consuming token remain quota: " + err.Error())
|
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||||
}
|
|
||||||
err = model.CacheUpdateUserQuota(userId)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error update user quota cache: " + err.Error())
|
|
||||||
}
|
|
||||||
if quota != 0 {
|
|
||||||
tokenName := c.GetString("token_name")
|
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
|
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
|
||||||
channelId := c.GetInt("channel_id")
|
|
||||||
model.UpdateChannelUsedQuota(channelId, quota)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}(c.Request.Context())
|
|
||||||
|
|
||||||
responseBody, err := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
}
|
||||||
err = resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(responseBody, &audioResponse)
|
|
||||||
if err != nil {
|
|
||||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
|
||||||
|
|
||||||
for k, v := range resp.Header {
|
for k, v := range resp.Header {
|
||||||
c.Writer.Header().Set(k, v[0])
|
c.Writer.Header().Set(k, v[0])
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
messages = append(messages, BaiduMessage{
|
messages = append(messages, BaiduMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
messages = append(messages, BaiduMessage{
|
messages = append(messages, BaiduMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
|||||||
} else {
|
} else {
|
||||||
messages = append(messages, BaiduMessage{
|
messages = append(messages, BaiduMessage{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,20 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isWithinRange(element string, value int) bool {
|
||||||
|
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
min := common.DalleGenerationImageAmounts[element][0]
|
||||||
|
max := common.DalleGenerationImageAmounts[element][1]
|
||||||
|
|
||||||
|
return value >= min && value <= max
|
||||||
|
}
|
||||||
|
|
||||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||||
imageModel := "dall-e"
|
imageModel := "dall-e-2"
|
||||||
|
imageSize := "1024x1024"
|
||||||
|
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
@@ -32,19 +44,44 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Size validation
|
||||||
|
if imageRequest.Size != "" {
|
||||||
|
imageSize = imageRequest.Size
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model validation
|
||||||
|
if imageRequest.Model != "" {
|
||||||
|
imageModel = imageRequest.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
|
||||||
|
|
||||||
|
// Check if model is supported
|
||||||
|
if hasValidSize {
|
||||||
|
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
|
||||||
|
if imageSize == "1024x1024" {
|
||||||
|
imageCostRatio *= 2
|
||||||
|
} else {
|
||||||
|
imageCostRatio *= 1.5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
// Prompt validation
|
// Prompt validation
|
||||||
if imageRequest.Prompt == "" {
|
if imageRequest.Prompt == "" {
|
||||||
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
return errorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not "256x256", "512x512", or "1024x1024"
|
// Check prompt length
|
||||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
|
||||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
|
return errorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// N should between 1 and 10
|
// Number of generated images validation
|
||||||
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
if isWithinRange(imageModel, imageRequest.N) == false {
|
||||||
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
return errorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
@@ -61,16 +98,12 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
isModelMapped = true
|
isModelMapped = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
|
|
||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
|
||||||
|
|
||||||
var requestBody io.Reader
|
var requestBody io.Reader
|
||||||
if isModelMapped {
|
if isModelMapped {
|
||||||
jsonStr, err := json.Marshal(imageRequest)
|
jsonStr, err := json.Marshal(imageRequest)
|
||||||
@@ -87,19 +120,10 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
|||||||
ratio := modelRatio * groupRatio
|
ratio := modelRatio * groupRatio
|
||||||
userQuota, err := model.CacheGetUserQuota(userId)
|
userQuota, err := model.CacheGetUserQuota(userId)
|
||||||
|
|
||||||
sizeRatio := 1.0
|
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
||||||
// Size
|
|
||||||
if imageRequest.Size == "256x256" {
|
|
||||||
sizeRatio = 1
|
|
||||||
} else if imageRequest.Size == "512x512" {
|
|
||||||
sizeRatio = 1.125
|
|
||||||
} else if imageRequest.Size == "1024x1024" {
|
|
||||||
sizeRatio = 1.25
|
|
||||||
}
|
|
||||||
quota := int(ratio*sizeRatio*1000) * imageRequest.N
|
|
||||||
|
|
||||||
if consumeQuota && userQuota-quota < 0 {
|
if consumeQuota && userQuota-quota < 0 {
|
||||||
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
|
|||||||
if textResponse.Usage.TotalTokens == 0 {
|
if textResponse.Usage.TotalTokens == 0 {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range textResponse.Choices {
|
for _, choice := range textResponse.Choices {
|
||||||
completionTokens += countTokenText(choice.Message.Content, model)
|
completionTokens += countTokenText(choice.Message.StringContent(), model)
|
||||||
}
|
}
|
||||||
textResponse.Usage = Usage{
|
textResponse.Usage = Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
|||||||
}
|
}
|
||||||
for _, message := range textRequest.Messages {
|
for _, message := range textRequest.Messages {
|
||||||
palmMessage := PaLMChatMessage{
|
palmMessage := PaLMChatMessage{
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
}
|
}
|
||||||
if message.Role == "user" {
|
if message.Role == "user" {
|
||||||
palmMessage.Author = "0"
|
palmMessage.Author = "0"
|
||||||
|
|||||||
287
controller/relay-tencent.go
Normal file
287
controller/relay-tencent.go
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
package controller
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// https://cloud.tencent.com/document/product/1729/97732
|
||||||
|
|
||||||
|
type TencentMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentChatRequest struct {
|
||||||
|
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||||
|
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||||
|
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||||
|
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||||
|
Timestamp int64 `json:"timestamp"`
|
||||||
|
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||||
|
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||||
|
Expired int64 `json:"expired"`
|
||||||
|
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||||
|
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||||
|
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||||
|
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||||
|
Temperature float64 `json:"temperature"`
|
||||||
|
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||||
|
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||||
|
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||||
|
TopP float64 `json:"top_p"`
|
||||||
|
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||||
|
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||||
|
Stream int `json:"stream"`
|
||||||
|
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||||
|
// 输入 content 总数最大支持 3000 token。
|
||||||
|
Messages []TencentMessage `json:"messages"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentError struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentUsage struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentResponseChoices struct {
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||||
|
Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
|
Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||||
|
}
|
||||||
|
|
||||||
|
type TencentChatResponse struct {
|
||||||
|
Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果
|
||||||
|
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||||
|
Id string `json:"id,omitempty"` // 会话 id
|
||||||
|
Usage Usage `json:"usage,omitempty"` // token 数量
|
||||||
|
Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||||
|
Note string `json:"note,omitempty"` // 注释
|
||||||
|
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
|
||||||
|
messages := make([]TencentMessage, 0, len(request.Messages))
|
||||||
|
for i := 0; i < len(request.Messages); i++ {
|
||||||
|
message := request.Messages[i]
|
||||||
|
if message.Role == "system" {
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Role: "user",
|
||||||
|
Content: message.StringContent(),
|
||||||
|
})
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Okay",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
messages = append(messages, TencentMessage{
|
||||||
|
Content: message.StringContent(),
|
||||||
|
Role: message.Role,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
stream := 0
|
||||||
|
if request.Stream {
|
||||||
|
stream = 1
|
||||||
|
}
|
||||||
|
return &TencentChatRequest{
|
||||||
|
Timestamp: common.GetTimestamp(),
|
||||||
|
Expired: common.GetTimestamp() + 24*60*60,
|
||||||
|
QueryID: common.GetUUID(),
|
||||||
|
Temperature: request.Temperature,
|
||||||
|
TopP: request.TopP,
|
||||||
|
Stream: stream,
|
||||||
|
Messages: messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
|
||||||
|
fullTextResponse := OpenAITextResponse{
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Usage: response.Usage,
|
||||||
|
}
|
||||||
|
if len(response.Choices) > 0 {
|
||||||
|
choice := OpenAITextResponseChoice{
|
||||||
|
Index: 0,
|
||||||
|
Message: Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: response.Choices[0].Messages.Content,
|
||||||
|
},
|
||||||
|
FinishReason: response.Choices[0].FinishReason,
|
||||||
|
}
|
||||||
|
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||||
|
}
|
||||||
|
return &fullTextResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse {
|
||||||
|
response := ChatCompletionsStreamResponse{
|
||||||
|
Object: "chat.completion.chunk",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: "tencent-hunyuan",
|
||||||
|
}
|
||||||
|
if len(TencentResponse.Choices) > 0 {
|
||||||
|
var choice ChatCompletionsStreamResponseChoice
|
||||||
|
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||||
|
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||||
|
choice.FinishReason = &stopFinishReason
|
||||||
|
}
|
||||||
|
response.Choices = append(response.Choices, choice)
|
||||||
|
}
|
||||||
|
return &response
|
||||||
|
}
|
||||||
|
|
||||||
|
func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||||
|
var responseText string
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
if len(data) < 5 { // ignore blank line or wrong format
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if data[:5] != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data = data[5:]
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
setEventStreamHeaders(c)
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
var TencentResponse TencentChatResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||||
|
if len(response.Choices) != 0 {
|
||||||
|
responseText += response.Choices[0].Delta.Content
|
||||||
|
}
|
||||||
|
jsonResponse, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
err := resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||||
|
}
|
||||||
|
return nil, responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||||
|
var TencentResponse TencentChatResponse
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
if TencentResponse.Error.Code != 0 {
|
||||||
|
return &OpenAIErrorWithStatusCode{
|
||||||
|
OpenAIError: OpenAIError{
|
||||||
|
Message: TencentResponse.Error.Message,
|
||||||
|
Code: TencentResponse.Error.Code,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||||
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &fullTextResponse.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||||
|
parts := strings.Split(config, "|")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
err = errors.New("invalid tencent config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
secretId = parts[1]
|
||||||
|
secretKey = parts[2]
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTencentSign(req TencentChatRequest, secretKey string) string {
|
||||||
|
params := make([]string, 0)
|
||||||
|
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||||
|
params = append(params, "secret_id="+req.SecretId)
|
||||||
|
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||||
|
params = append(params, "query_id="+req.QueryID)
|
||||||
|
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||||
|
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||||
|
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||||
|
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||||
|
|
||||||
|
var messageStr string
|
||||||
|
for _, msg := range req.Messages {
|
||||||
|
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||||
|
}
|
||||||
|
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||||
|
params = append(params, "messages=["+messageStr+"]")
|
||||||
|
|
||||||
|
sort.Sort(sort.StringSlice(params))
|
||||||
|
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||||
|
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||||
|
signURL := url
|
||||||
|
mac.Write([]byte(signURL))
|
||||||
|
sign := mac.Sum([]byte(nil))
|
||||||
|
return base64.StdEncoding.EncodeToString(sign)
|
||||||
|
}
|
||||||
@@ -6,13 +6,15 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -24,13 +26,21 @@ const (
|
|||||||
APITypeAli
|
APITypeAli
|
||||||
APITypeXunfei
|
APITypeXunfei
|
||||||
APITypeAIProxyLibrary
|
APITypeAIProxyLibrary
|
||||||
|
APITypeTencent
|
||||||
)
|
)
|
||||||
|
|
||||||
var httpClient *http.Client
|
var httpClient *http.Client
|
||||||
var impatientHTTPClient *http.Client
|
var impatientHTTPClient *http.Client
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
httpClient = &http.Client{}
|
if common.RelayTimeout == 0 {
|
||||||
|
httpClient = &http.Client{}
|
||||||
|
} else {
|
||||||
|
httpClient = &http.Client{
|
||||||
|
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impatientHTTPClient = &http.Client{
|
impatientHTTPClient = &http.Client{
|
||||||
Timeout: 5 * time.Second,
|
Timeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
@@ -109,13 +119,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
apiType = APITypeXunfei
|
apiType = APITypeXunfei
|
||||||
case common.ChannelTypeAIProxyLibrary:
|
case common.ChannelTypeAIProxyLibrary:
|
||||||
apiType = APITypeAIProxyLibrary
|
apiType = APITypeAIProxyLibrary
|
||||||
|
case common.ChannelTypeTencent:
|
||||||
|
apiType = APITypeTencent
|
||||||
}
|
}
|
||||||
baseURL := common.ChannelBaseURLs[channelType]
|
baseURL := common.ChannelBaseURLs[channelType]
|
||||||
requestURL := c.Request.URL.String()
|
requestURL := c.Request.URL.String()
|
||||||
if c.GetString("base_url") != "" {
|
if c.GetString("base_url") != "" {
|
||||||
baseURL = c.GetString("base_url")
|
baseURL = c.GetString("base_url")
|
||||||
}
|
}
|
||||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case APITypeOpenAI:
|
case APITypeOpenAI:
|
||||||
if channelType == common.ChannelTypeAzure {
|
if channelType == common.ChannelTypeAzure {
|
||||||
@@ -135,7 +147,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
model_ = strings.TrimSuffix(model_, "-0301")
|
model_ = strings.TrimSuffix(model_, "-0301")
|
||||||
model_ = strings.TrimSuffix(model_, "-0314")
|
model_ = strings.TrimSuffix(model_, "-0314")
|
||||||
model_ = strings.TrimSuffix(model_, "-0613")
|
model_ = strings.TrimSuffix(model_, "-0613")
|
||||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
|
||||||
|
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||||
|
fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType)
|
||||||
}
|
}
|
||||||
case APITypeClaude:
|
case APITypeClaude:
|
||||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||||
@@ -148,6 +162,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||||
case "ERNIE-Bot-turbo":
|
case "ERNIE-Bot-turbo":
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||||
|
case "ERNIE-Bot-4":
|
||||||
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||||
case "BLOOMZ-7B":
|
case "BLOOMZ-7B":
|
||||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||||
case "Embedding-V1":
|
case "Embedding-V1":
|
||||||
@@ -179,6 +195,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if relayMode == RelayModeEmbeddings {
|
if relayMode == RelayModeEmbeddings {
|
||||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||||
}
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
||||||
case APITypeAIProxyLibrary:
|
case APITypeAIProxyLibrary:
|
||||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||||
}
|
}
|
||||||
@@ -204,6 +222,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
if userQuota-preConsumedQuota < 0 {
|
||||||
|
return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||||
|
}
|
||||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||||
@@ -282,6 +303,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
requestBody = bytes.NewBuffer(jsonStr)
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
|
case APITypeTencent:
|
||||||
|
apiKey := c.Request.Header.Get("Authorization")
|
||||||
|
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||||
|
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
tencentRequest := requestOpenAI2Tencent(textRequest)
|
||||||
|
tencentRequest.AppId = appId
|
||||||
|
tencentRequest.SecretId = secretId
|
||||||
|
jsonStr, err := json.Marshal(tencentRequest)
|
||||||
|
if err != nil {
|
||||||
|
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
sign := getTencentSign(*tencentRequest, secretKey)
|
||||||
|
c.Request.Header.Set("Authorization", sign)
|
||||||
|
requestBody = bytes.NewBuffer(jsonStr)
|
||||||
case APITypeAIProxyLibrary:
|
case APITypeAIProxyLibrary:
|
||||||
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||||
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||||
@@ -329,11 +367,18 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
if textRequest.Stream {
|
if textRequest.Stream {
|
||||||
req.Header.Set("X-DashScope-SSE", "enable")
|
req.Header.Set("X-DashScope-SSE", "enable")
|
||||||
}
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
req.Header.Set("Authorization", apiKey)
|
||||||
|
case APITypePaLM:
|
||||||
|
// do not set Authorization header
|
||||||
default:
|
default:
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||||
|
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
}
|
||||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||||
resp, err = httpClient.Do(req)
|
resp, err = httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -374,9 +419,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||||
promptTokens = textResponse.Usage.PromptTokens
|
promptTokens = textResponse.Usage.PromptTokens
|
||||||
completionTokens = textResponse.Usage.CompletionTokens
|
completionTokens = textResponse.Usage.CompletionTokens
|
||||||
|
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
|
||||||
quota = int(float64(quota) * ratio)
|
|
||||||
if ratio != 0 && quota <= 0 {
|
if ratio != 0 && quota <= 0 {
|
||||||
quota = 1
|
quota = 1
|
||||||
}
|
}
|
||||||
@@ -581,6 +624,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
case APITypeTencent:
|
||||||
|
if isStream {
|
||||||
|
err, responseText := tencentStreamHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
textResponse.Usage.PromptTokens = promptTokens
|
||||||
|
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
err, usage := tencentHandler(c, resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if usage != nil {
|
||||||
|
textResponse.Usage = *usage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,52 +1,64 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/pkoukk/tiktoken-go"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
)
|
)
|
||||||
|
|
||||||
var stopFinishReason = "stop"
|
var stopFinishReason = "stop"
|
||||||
|
|
||||||
|
// tokenEncoderMap won't grow after initialization
|
||||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||||
|
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||||
|
|
||||||
func InitTokenEncoders() {
|
func InitTokenEncoders() {
|
||||||
common.SysLog("initializing token encoders")
|
common.SysLog("initializing token encoders")
|
||||||
fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
|
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||||
|
}
|
||||||
|
defaultTokenEncoder = gpt35TokenEncoder
|
||||||
|
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||||
|
if err != nil {
|
||||||
|
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||||
}
|
}
|
||||||
for model, _ := range common.ModelRatio {
|
for model, _ := range common.ModelRatio {
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
if strings.HasPrefix(model, "gpt-3.5") {
|
||||||
if err != nil {
|
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||||
common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
|
} else if strings.HasPrefix(model, "gpt-4") {
|
||||||
tokenEncoderMap[model] = fallbackTokenEncoder
|
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||||
continue
|
} else {
|
||||||
|
tokenEncoderMap[model] = nil
|
||||||
}
|
}
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
|
||||||
}
|
}
|
||||||
common.SysLog("token encoders initialized")
|
common.SysLog("token encoders initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||||
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
|
tokenEncoder, ok := tokenEncoderMap[model]
|
||||||
|
if ok && tokenEncoder != nil {
|
||||||
return tokenEncoder
|
return tokenEncoder
|
||||||
}
|
}
|
||||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
if ok {
|
||||||
if err != nil {
|
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
|
||||||
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
|
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||||
|
tokenEncoder = defaultTokenEncoder
|
||||||
}
|
}
|
||||||
|
tokenEncoderMap[model] = tokenEncoder
|
||||||
|
return tokenEncoder
|
||||||
}
|
}
|
||||||
tokenEncoderMap[model] = tokenEncoder
|
return defaultTokenEncoder
|
||||||
return tokenEncoder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||||
@@ -75,7 +87,7 @@ func countTokenMessages(messages []Message, model string) int {
|
|||||||
tokenNum := 0
|
tokenNum := 0
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tokenNum += tokensPerMessage
|
tokenNum += tokensPerMessage
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Content)
|
tokenNum += getTokenNum(tokenEncoder, message.StringContent())
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if message.Name != nil {
|
if message.Name != nil {
|
||||||
tokenNum += tokensPerName
|
tokenNum += tokensPerName
|
||||||
@@ -167,3 +179,35 @@ func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIEr
|
|||||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||||
|
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||||
|
|
||||||
|
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||||
|
switch channelType {
|
||||||
|
case common.ChannelTypeOpenAI:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||||
|
case common.ChannelTypeAzure:
|
||||||
|
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullRequestURL
|
||||||
|
}
|
||||||
|
|
||||||
|
func postConsumeQuota(ctx context.Context, tokenId int, quota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||||
|
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error consuming token remain quota: " + err.Error())
|
||||||
|
}
|
||||||
|
err = model.CacheUpdateUserQuota(userId)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error update user quota cache: " + err.Error())
|
||||||
|
}
|
||||||
|
if quota != 0 {
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||||
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, modelName, tokenName, quota, logContent)
|
||||||
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
|
model.UpdateChannelUsedQuota(channelId, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
|||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
messages = append(messages, XunfeiMessage{
|
messages = append(messages, XunfeiMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
messages = append(messages, XunfeiMessage{
|
messages = append(messages, XunfeiMessage{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
|||||||
} else {
|
} else {
|
||||||
messages = append(messages, XunfeiMessage{
|
messages = append(messages, XunfeiMessage{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -220,6 +220,9 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
|
|||||||
for !stop {
|
for !stop {
|
||||||
select {
|
select {
|
||||||
case xunfeiResponse = <-dataChan:
|
case xunfeiResponse = <-dataChan:
|
||||||
|
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||||
@@ -295,8 +298,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
|
|||||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||||
}
|
}
|
||||||
domain := "general"
|
domain := "general"
|
||||||
if apiVersion == "v2.1" {
|
if apiVersion != "v1.1" {
|
||||||
domain = "generalv2"
|
domain += strings.Split(apiVersion, ".")[0]
|
||||||
}
|
}
|
||||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||||
return domain, authUrl
|
return domain, authUrl
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
|||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
messages = append(messages, ZhipuMessage{
|
messages = append(messages, ZhipuMessage{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
messages = append(messages, ZhipuMessage{
|
messages = append(messages, ZhipuMessage{
|
||||||
Role: "user",
|
Role: "user",
|
||||||
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
|||||||
} else {
|
} else {
|
||||||
messages = append(messages, ZhipuMessage{
|
messages = append(messages, ZhipuMessage{
|
||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: message.Content,
|
Content: message.StringContent(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,10 +12,49 @@ import (
|
|||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content any `json:"content"`
|
||||||
Name *string `json:"name,omitempty"`
|
Name *string `json:"name,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ImageURL struct {
|
||||||
|
Url string `json:"url,omitempty"`
|
||||||
|
Detail string `json:"detail,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TextContent struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImageContent struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m Message) StringContent() string {
|
||||||
|
content, ok := m.Content.(string)
|
||||||
|
if ok {
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
contentList, ok := m.Content.([]any)
|
||||||
|
if ok {
|
||||||
|
var contentStr string
|
||||||
|
for _, contentItem := range contentList {
|
||||||
|
contentMap, ok := contentItem.(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if contentMap["type"] == "text" {
|
||||||
|
if subStr, ok := contentMap["text"].(string); ok {
|
||||||
|
contentStr += subStr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return contentStr
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RelayModeUnknown = iota
|
RelayModeUnknown = iota
|
||||||
RelayModeChatCompletions
|
RelayModeChatCompletions
|
||||||
@@ -24,24 +63,37 @@ const (
|
|||||||
RelayModeModerations
|
RelayModeModerations
|
||||||
RelayModeImagesGenerations
|
RelayModeImagesGenerations
|
||||||
RelayModeEdits
|
RelayModeEdits
|
||||||
RelayModeAudio
|
RelayModeAudioSpeech
|
||||||
|
RelayModeAudioTranscription
|
||||||
|
RelayModeAudioTranslation
|
||||||
)
|
)
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/chat
|
// https://platform.openai.com/docs/api-reference/chat
|
||||||
|
|
||||||
|
type ResponseFormat struct {
|
||||||
|
Type string `json:"type,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type GeneralOpenAIRequest struct {
|
type GeneralOpenAIRequest struct {
|
||||||
Model string `json:"model,omitempty"`
|
Model string `json:"model,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
Prompt any `json:"prompt,omitempty"`
|
Prompt any `json:"prompt,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
MaxTokens int `json:"max_tokens,omitempty"`
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
N int `json:"n,omitempty"`
|
N int `json:"n,omitempty"`
|
||||||
Input any `json:"input,omitempty"`
|
Input any `json:"input,omitempty"`
|
||||||
Instruction string `json:"instruction,omitempty"`
|
Instruction string `json:"instruction,omitempty"`
|
||||||
Size string `json:"size,omitempty"`
|
Size string `json:"size,omitempty"`
|
||||||
Functions any `json:"functions,omitempty"`
|
Functions any `json:"functions,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
|
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||||
|
Seed float64 `json:"seed,omitempty"`
|
||||||
|
Tools any `json:"tools,omitempty"`
|
||||||
|
ToolChoice any `json:"tool_choice,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||||
@@ -77,16 +129,30 @@ type TextRequest struct {
|
|||||||
//Stream bool `json:"stream"`
|
//Stream bool `json:"stream"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
|
||||||
type ImageRequest struct {
|
type ImageRequest struct {
|
||||||
Prompt string `json:"prompt"`
|
Model string `json:"model"`
|
||||||
N int `json:"n"`
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
Size string `json:"size"`
|
N int `json:"n"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
Quality string `json:"quality"`
|
||||||
|
ResponseFormat string `json:"response_format"`
|
||||||
|
Style string `json:"style"`
|
||||||
|
User string `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AudioResponse struct {
|
type WhisperResponse struct {
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TextToSpeechRequest struct {
|
||||||
|
Model string `json:"model" binding:"required"`
|
||||||
|
Input string `json:"input" binding:"required"`
|
||||||
|
Voice string `json:"voice" binding:"required"`
|
||||||
|
Speed float64 `json:"speed"`
|
||||||
|
ResponseFormat string `json:"response_format"`
|
||||||
|
}
|
||||||
|
|
||||||
type Usage struct {
|
type Usage struct {
|
||||||
PromptTokens int `json:"prompt_tokens"`
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
CompletionTokens int `json:"completion_tokens"`
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
@@ -183,14 +249,22 @@ func Relay(c *gin.Context) {
|
|||||||
relayMode = RelayModeImagesGenerations
|
relayMode = RelayModeImagesGenerations
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||||
relayMode = RelayModeEdits
|
relayMode = RelayModeEdits
|
||||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||||
relayMode = RelayModeAudio
|
relayMode = RelayModeAudioSpeech
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||||
|
relayMode = RelayModeAudioTranscription
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||||
|
relayMode = RelayModeAudioTranslation
|
||||||
}
|
}
|
||||||
var err *OpenAIErrorWithStatusCode
|
var err *OpenAIErrorWithStatusCode
|
||||||
switch relayMode {
|
switch relayMode {
|
||||||
case RelayModeImagesGenerations:
|
case RelayModeImagesGenerations:
|
||||||
err = relayImageHelper(c, relayMode)
|
err = relayImageHelper(c, relayMode)
|
||||||
case RelayModeAudio:
|
case RelayModeAudioSpeech:
|
||||||
|
fallthrough
|
||||||
|
case RelayModeAudioTranslation:
|
||||||
|
fallthrough
|
||||||
|
case RelayModeAudioTranscription:
|
||||||
err = relayAudioHelper(c, relayMode)
|
err = relayAudioHelper(c, relayMode)
|
||||||
default:
|
default:
|
||||||
err = relayTextHelper(c, relayMode)
|
err = relayTextHelper(c, relayMode)
|
||||||
|
|||||||
@@ -9,21 +9,21 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- "3000:3000"
|
- "3000:3000"
|
||||||
volumes:
|
volumes:
|
||||||
- ./data:/data
|
- ./data/oneapi:/data
|
||||||
- ./logs:/app/logs
|
- ./logs:/app/logs
|
||||||
environment:
|
environment:
|
||||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
- SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||||
- REDIS_CONN_STRING=redis://redis
|
- REDIS_CONN_STRING=redis://redis
|
||||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||||
- TZ=Asia/Shanghai
|
- TZ=Asia/Shanghai
|
||||||
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行
|
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行
|
||||||
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
|
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
|
||||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
|
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
|
||||||
|
|
||||||
depends_on:
|
depends_on:
|
||||||
- redis
|
- redis
|
||||||
|
- db
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
|
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||||
interval: 30s
|
interval: 30s
|
||||||
timeout: 10s
|
timeout: 10s
|
||||||
retries: 3
|
retries: 3
|
||||||
@@ -32,3 +32,18 @@ services:
|
|||||||
image: redis:latest
|
image: redis:latest
|
||||||
container_name: redis
|
container_name: redis
|
||||||
restart: always
|
restart: always
|
||||||
|
|
||||||
|
db:
|
||||||
|
image: mysql:8.2.0
|
||||||
|
restart: always
|
||||||
|
container_name: mysql
|
||||||
|
volumes:
|
||||||
|
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||||
|
ports:
|
||||||
|
- '3306:3306'
|
||||||
|
environment:
|
||||||
|
TZ: Asia/Shanghai # 设置时区
|
||||||
|
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||||
|
MYSQL_USER: oneapi # 创建专用用户
|
||||||
|
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||||
|
MYSQL_DATABASE: one-api # 自动创建数据库
|
||||||
10
go.mod
10
go.mod
@@ -15,8 +15,9 @@ require (
|
|||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
github.com/pkoukk/tiktoken-go v0.1.5
|
github.com/pkoukk/tiktoken-go v0.1.5
|
||||||
golang.org/x/crypto v0.9.0
|
golang.org/x/crypto v0.14.0
|
||||||
gorm.io/driver/mysql v1.4.3
|
gorm.io/driver/mysql v1.4.3
|
||||||
|
gorm.io/driver/postgres v1.5.2
|
||||||
gorm.io/driver/sqlite v1.4.3
|
gorm.io/driver/sqlite v1.4.3
|
||||||
gorm.io/gorm v1.25.0
|
gorm.io/gorm v1.25.0
|
||||||
)
|
)
|
||||||
@@ -52,10 +53,9 @@ require (
|
|||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||||
golang.org/x/arch v0.3.0 // indirect
|
golang.org/x/arch v0.3.0 // indirect
|
||||||
golang.org/x/net v0.10.0 // indirect
|
golang.org/x/net v0.17.0 // indirect
|
||||||
golang.org/x/sys v0.8.0 // indirect
|
golang.org/x/sys v0.13.0 // indirect
|
||||||
golang.org/x/text v0.9.0 // indirect
|
golang.org/x/text v0.13.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/protobuf v1.30.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
gorm.io/driver/postgres v1.5.2 // indirect
|
|
||||||
)
|
)
|
||||||
|
|||||||
17
go.sum
17
go.sum
@@ -150,11 +150,11 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
|||||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
@@ -162,14 +162,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
|||||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
@@ -198,7 +198,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
|
|||||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||||
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
|
|
||||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
||||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) {
|
|||||||
if ok {
|
if ok {
|
||||||
id, err := strconv.Atoi(channelId.(string))
|
id, err := strconv.Atoi(channelId.(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
channel, err = model.GetChannelById(id, true)
|
channel, err = model.GetChannelById(id, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if channel.Status != common.ChannelStatusEnabled {
|
if channel.Status != common.ChannelStatusEnabled {
|
||||||
@@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
|
|||||||
} else {
|
} else {
|
||||||
// Select a channel for the user
|
// Select a channel for the user
|
||||||
var modelRequest ModelRequest
|
var modelRequest ModelRequest
|
||||||
var err error
|
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||||
return
|
return
|
||||||
@@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "dall-e"
|
modelRequest.Model = "dall-e-2"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||||
if modelRequest.Model == "" {
|
if modelRequest.Model == "" {
|
||||||
modelRequest.Model = "whisper-1"
|
modelRequest.Model = "whisper-1"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,10 +15,17 @@ type Ability struct {
|
|||||||
|
|
||||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||||
ability := Ability{}
|
ability := Ability{}
|
||||||
|
groupCol := "`group`"
|
||||||
|
trueVal := "1"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
trueVal = "true"
|
||||||
|
}
|
||||||
|
|
||||||
var err error = nil
|
var err error = nil
|
||||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
|
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||||
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
|
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||||
if common.UsingSQLite {
|
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||||
} else {
|
} else {
|
||||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||||
|
|||||||
@@ -21,14 +21,18 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||||
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
var token Token
|
var token Token
|
||||||
if !common.RedisEnabled {
|
if !common.RedisEnabled {
|
||||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||||
return &token, err
|
return &token, err
|
||||||
}
|
}
|
||||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ type Channel struct {
|
|||||||
Key string `json:"key" gorm:"not null;index"`
|
Key string `json:"key" gorm:"not null;index"`
|
||||||
Status int `json:"status" gorm:"default:1"`
|
Status int `json:"status" gorm:"default:1"`
|
||||||
Name string `json:"name" gorm:"index"`
|
Name string `json:"name" gorm:"index"`
|
||||||
Weight int `json:"weight"`
|
Weight *uint `json:"weight" gorm:"default:0"`
|
||||||
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
CreatedTime int64 `json:"created_time" gorm:"bigint"`
|
||||||
TestTime int64 `json:"test_time" gorm:"bigint"`
|
TestTime int64 `json:"test_time" gorm:"bigint"`
|
||||||
ResponseTime int `json:"response_time"` // in milliseconds
|
ResponseTime int `json:"response_time"` // in milliseconds
|
||||||
@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
|
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
|
||||||
return channels, err
|
return channels, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
|||||||
return &channel, err
|
return &channel, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRandomChannel() (*Channel, error) {
|
|
||||||
channel := Channel{}
|
|
||||||
var err error = nil
|
|
||||||
if common.UsingSQLite {
|
|
||||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
|
||||||
} else {
|
|
||||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
|
|
||||||
}
|
|
||||||
return &channel, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func BatchInsertChannels(channels []Channel) error {
|
func BatchInsertChannels(channels []Channel) error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(&channels).Error
|
err = DB.Create(&channels).Error
|
||||||
@@ -176,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) {
|
|||||||
common.SysError("failed to update channel used quota: " + err.Error())
|
common.SysError("failed to update channel used quota: " + err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeleteChannelByStatus(status int64) (int64, error) {
|
||||||
|
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteDisabledChannel() (int64, error) {
|
||||||
|
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||||
|
return result.RowsAffected, result.Error
|
||||||
|
}
|
||||||
|
|||||||
21
model/log.go
21
model/log.go
@@ -8,18 +8,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Log struct {
|
type Log struct {
|
||||||
Id int `json:"id"`
|
Id int `json:"id;index:idx_created_at_id,priority:1"`
|
||||||
UserId int `json:"user_id"`
|
UserId int `json:"user_id" gorm:"index"`
|
||||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
|
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"`
|
||||||
Type int `json:"type" gorm:"index"`
|
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
Username string `json:"username" gorm:"index;default:''"`
|
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
||||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||||
ModelName string `json:"model_name" gorm:"index;default:''"`
|
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||||
Quota int `json:"quota" gorm:"default:0"`
|
Quota int `json:"quota" gorm:"default:0"`
|
||||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||||
Channel int `json:"channel" gorm:"default:0"`
|
ChannelId int `json:"channel" gorm:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -47,7 +47,6 @@ func RecordLog(userId int, logType int, content string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
|
||||||
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
|
||||||
if !common.LogConsumeEnabled {
|
if !common.LogConsumeEnabled {
|
||||||
@@ -64,7 +63,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
|
|||||||
TokenName: tokenName,
|
TokenName: tokenName,
|
||||||
ModelName: modelName,
|
ModelName: modelName,
|
||||||
Quota: quota,
|
Quota: quota,
|
||||||
Channel: channelId,
|
ChannelId: channelId,
|
||||||
}
|
}
|
||||||
err := DB.Create(log).Error
|
err := DB.Create(log).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -95,7 +94,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
|||||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||||
}
|
}
|
||||||
if channel != 0 {
|
if channel != 0 {
|
||||||
tx = tx.Where("channel = ?", channel)
|
tx = tx.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||||
return logs, err
|
return logs, err
|
||||||
@@ -152,7 +151,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
|||||||
tx = tx.Where("model_name = ?", modelName)
|
tx = tx.Where("model_name = ?", modelName)
|
||||||
}
|
}
|
||||||
if channel != 0 {
|
if channel != 0 {
|
||||||
tx = tx.Where("channel = ?", channel)
|
tx = tx.Where("channel_id = ?", channel)
|
||||||
}
|
}
|
||||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||||
return quota
|
return quota
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ func chooseDB() (*gorm.DB, error) {
|
|||||||
if strings.HasPrefix(dsn, "postgres://") {
|
if strings.HasPrefix(dsn, "postgres://") {
|
||||||
// Use PostgreSQL
|
// Use PostgreSQL
|
||||||
common.SysLog("using PostgreSQL as database")
|
common.SysLog("using PostgreSQL as database")
|
||||||
|
common.UsingPostgreSQL = true
|
||||||
return gorm.Open(postgres.New(postgres.Config{
|
return gorm.Open(postgres.New(postgres.Config{
|
||||||
DSN: dsn,
|
DSN: dsn,
|
||||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||||
@@ -81,6 +82,7 @@ func InitDB() (err error) {
|
|||||||
if !common.IsMasterNode {
|
if !common.IsMasterNode {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
common.SysLog("database migration started")
|
||||||
err = db.AutoMigrate(&Channel{})
|
err = db.AutoMigrate(&Channel{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
|
|||||||
}
|
}
|
||||||
redemption := &Redemption{}
|
redemption := &Redemption{}
|
||||||
|
|
||||||
|
keyCol := "`key`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
keyCol = `"key"`
|
||||||
|
}
|
||||||
|
|
||||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.New("无效的兑换码")
|
return errors.New("无效的兑换码")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -266,7 +266,12 @@ func GetUserEmail(id int) (email string, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetUserGroup(id int) (group string, err error) {
|
func GetUserGroup(id int) (group string, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
|
groupCol := "`group`"
|
||||||
|
if common.UsingPostgreSQL {
|
||||||
|
groupCol = `"group"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||||
return group, err
|
return group, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -309,7 +314,8 @@ func GetRootUserEmail() (email string) {
|
|||||||
|
|
||||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||||
if common.BatchUpdateEnabled {
|
if common.BatchUpdateEnabled {
|
||||||
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
|
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||||
|
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||||
@@ -327,6 +333,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func updateUserUsedQuota(id int, quota int) {
|
||||||
|
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||||
|
map[string]interface{}{
|
||||||
|
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||||
|
},
|
||||||
|
).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update user used quota: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateUserRequestCount(id int, count int) {
|
||||||
|
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("failed to update user request count: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func GetUsernameById(id int) (username string) {
|
func GetUsernameById(id int) (username string) {
|
||||||
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
||||||
return username
|
return username
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
BatchUpdateTypeUserQuota = iota
|
BatchUpdateTypeUserQuota = iota
|
||||||
BatchUpdateTypeTokenQuota
|
BatchUpdateTypeTokenQuota
|
||||||
BatchUpdateTypeUsedQuotaAndRequestCount
|
BatchUpdateTypeUsedQuota
|
||||||
BatchUpdateTypeChannelUsedQuota
|
BatchUpdateTypeChannelUsedQuota
|
||||||
|
BatchUpdateTypeRequestCount
|
||||||
|
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||||
)
|
)
|
||||||
|
|
||||||
var batchUpdateStores []map[int]int
|
var batchUpdateStores []map[int]int
|
||||||
@@ -51,7 +51,7 @@ func batchUpdate() {
|
|||||||
store := batchUpdateStores[i]
|
store := batchUpdateStores[i]
|
||||||
batchUpdateStores[i] = make(map[int]int)
|
batchUpdateStores[i] = make(map[int]int)
|
||||||
batchUpdateLocks[i].Unlock()
|
batchUpdateLocks[i].Unlock()
|
||||||
|
// TODO: maybe we can combine updates with same key?
|
||||||
for key, value := range store {
|
for key, value := range store {
|
||||||
switch i {
|
switch i {
|
||||||
case BatchUpdateTypeUserQuota:
|
case BatchUpdateTypeUserQuota:
|
||||||
@@ -64,8 +64,10 @@ func batchUpdate() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("failed to batch update token quota: " + err.Error())
|
common.SysError("failed to batch update token quota: " + err.Error())
|
||||||
}
|
}
|
||||||
case BatchUpdateTypeUsedQuotaAndRequestCount:
|
case BatchUpdateTypeUsedQuota:
|
||||||
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
|
updateUserUsedQuota(key, value)
|
||||||
|
case BatchUpdateTypeRequestCount:
|
||||||
|
updateUserRequestCount(key, value)
|
||||||
case BatchUpdateTypeChannelUsedQuota:
|
case BatchUpdateTypeChannelUsedQuota:
|
||||||
updateChannelUsedQuota(key, value)
|
updateChannelUsedQuota(key, value)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ func SetApiRouter(router *gin.Engine) {
|
|||||||
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
||||||
channelRoute.POST("/", controller.AddChannel)
|
channelRoute.POST("/", controller.AddChannel)
|
||||||
channelRoute.PUT("/", controller.UpdateChannel)
|
channelRoute.PUT("/", controller.UpdateChannel)
|
||||||
|
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
|
||||||
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
||||||
}
|
}
|
||||||
tokenRoute := apiRouter.Group("/token")
|
tokenRoute := apiRouter.Group("/token")
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||||
relayV1Router.POST("/audio/translations", controller.Relay)
|
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||||
|
relayV1Router.POST("/audio/speech", controller.Relay)
|
||||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||||
|
|||||||
@@ -283,7 +283,9 @@ function App() {
|
|||||||
</Suspense>
|
</Suspense>
|
||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<Route path='*' element={NotFound} />
|
<Route path='*' element={
|
||||||
|
<NotFound />
|
||||||
|
} />
|
||||||
</Routes>
|
</Routes>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import React, { useEffect, useState } from 'react';
|
import React, { useEffect, useState } from 'react';
|
||||||
import {Button, Form, Input, Label, Pagination, Popup, Table} from 'semantic-ui-react';
|
import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
|
||||||
import { Link } from 'react-router-dom';
|
import { Link } from 'react-router-dom';
|
||||||
import { API, showError, showInfo, showNotice, showSuccess, timestamp2string } from '../helpers';
|
import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
|
||||||
|
|
||||||
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
|
||||||
import { renderGroup, renderNumber } from '../helpers/render';
|
import { renderGroup, renderNumber } from '../helpers/render';
|
||||||
@@ -55,6 +55,7 @@ const ChannelsTable = () => {
|
|||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
const [searching, setSearching] = useState(false);
|
const [searching, setSearching] = useState(false);
|
||||||
const [updatingBalance, setUpdatingBalance] = useState(false);
|
const [updatingBalance, setUpdatingBalance] = useState(false);
|
||||||
|
const [showPrompt, setShowPrompt] = useState(shouldShowPrompt("channel-test"));
|
||||||
|
|
||||||
const loadChannels = async (startIdx) => {
|
const loadChannels = async (startIdx) => {
|
||||||
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
const res = await API.get(`/api/channel/?p=${startIdx}`);
|
||||||
@@ -96,7 +97,7 @@ const ChannelsTable = () => {
|
|||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const manageChannel = async (id, action, idx, priority) => {
|
const manageChannel = async (id, action, idx, value) => {
|
||||||
let data = { id };
|
let data = { id };
|
||||||
let res;
|
let res;
|
||||||
switch (action) {
|
switch (action) {
|
||||||
@@ -112,10 +113,20 @@ const ChannelsTable = () => {
|
|||||||
res = await API.put('/api/channel/', data);
|
res = await API.put('/api/channel/', data);
|
||||||
break;
|
break;
|
||||||
case 'priority':
|
case 'priority':
|
||||||
if (priority === '') {
|
if (value === '') {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
data.priority = parseInt(priority);
|
data.priority = parseInt(value);
|
||||||
|
res = await API.put('/api/channel/', data);
|
||||||
|
break;
|
||||||
|
case 'weight':
|
||||||
|
if (value === '') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
data.weight = parseInt(value);
|
||||||
|
if (data.weight < 0) {
|
||||||
|
data.weight = 0;
|
||||||
|
}
|
||||||
res = await API.put('/api/channel/', data);
|
res = await API.put('/api/channel/', data);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -142,9 +153,23 @@ const ChannelsTable = () => {
|
|||||||
return <Label basic color='green'>已启用</Label>;
|
return <Label basic color='green'>已启用</Label>;
|
||||||
case 2:
|
case 2:
|
||||||
return (
|
return (
|
||||||
<Label basic color='red'>
|
<Popup
|
||||||
已禁用
|
trigger={<Label basic color='red'>
|
||||||
</Label>
|
已禁用
|
||||||
|
</Label>}
|
||||||
|
content='本渠道被手动禁用'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
case 3:
|
||||||
|
return (
|
||||||
|
<Popup
|
||||||
|
trigger={<Label basic color='yellow'>
|
||||||
|
已禁用
|
||||||
|
</Label>}
|
||||||
|
content='本渠道被程序自动禁用'
|
||||||
|
basic
|
||||||
|
/>
|
||||||
);
|
);
|
||||||
default:
|
default:
|
||||||
return (
|
return (
|
||||||
@@ -202,7 +227,6 @@ const ChannelsTable = () => {
|
|||||||
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
|
showInfo(`通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
showNotice("当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo 模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。")
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -216,6 +240,17 @@ const ChannelsTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const deleteAllDisabledChannels = async () => {
|
||||||
|
const res = await API.delete(`/api/channel/disabled`);
|
||||||
|
const { success, message, data } = res.data;
|
||||||
|
if (success) {
|
||||||
|
showSuccess(`已删除所有禁用渠道,共计 ${data} 个`);
|
||||||
|
await refresh();
|
||||||
|
} else {
|
||||||
|
showError(message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const updateChannelBalance = async (id, name, idx) => {
|
const updateChannelBalance = async (id, name, idx) => {
|
||||||
const res = await API.get(`/api/channel/update_balance/${id}/`);
|
const res = await API.get(`/api/channel/update_balance/${id}/`);
|
||||||
const { success, message, balance } = res.data;
|
const { success, message, balance } = res.data;
|
||||||
@@ -251,17 +286,15 @@ const ChannelsTable = () => {
|
|||||||
if (channels.length === 0) return;
|
if (channels.length === 0) return;
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
let sortedChannels = [...channels];
|
let sortedChannels = [...channels];
|
||||||
if (typeof sortedChannels[0][key] === 'string') {
|
sortedChannels.sort((a, b) => {
|
||||||
sortedChannels.sort((a, b) => {
|
if (!isNaN(a[key])) {
|
||||||
|
// If the value is numeric, subtract to sort
|
||||||
|
return a[key] - b[key];
|
||||||
|
} else {
|
||||||
|
// If the value is not numeric, sort as strings
|
||||||
return ('' + a[key]).localeCompare(b[key]);
|
return ('' + a[key]).localeCompare(b[key]);
|
||||||
});
|
}
|
||||||
} else {
|
});
|
||||||
sortedChannels.sort((a, b) => {
|
|
||||||
if (a[key] === b[key]) return 0;
|
|
||||||
if (a[key] > b[key]) return -1;
|
|
||||||
if (a[key] < b[key]) return 1;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
if (sortedChannels[0].id === channels[0].id) {
|
if (sortedChannels[0].id === channels[0].id) {
|
||||||
sortedChannels.reverse();
|
sortedChannels.reverse();
|
||||||
}
|
}
|
||||||
@@ -269,6 +302,7 @@ const ChannelsTable = () => {
|
|||||||
setLoading(false);
|
setLoading(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Form onSubmit={searchChannels}>
|
<Form onSubmit={searchChannels}>
|
||||||
@@ -282,7 +316,19 @@ const ChannelsTable = () => {
|
|||||||
onChange={handleKeywordChange}
|
onChange={handleKeywordChange}
|
||||||
/>
|
/>
|
||||||
</Form>
|
</Form>
|
||||||
|
{
|
||||||
|
showPrompt && (
|
||||||
|
<Message onDismiss={() => {
|
||||||
|
setShowPrompt(false);
|
||||||
|
setPromptShown("channel-test");
|
||||||
|
}}>
|
||||||
|
当前版本测试是通过按照 OpenAI API 格式使用 gpt-3.5-turbo
|
||||||
|
模型进行非流式请求实现的,因此测试报错并不一定代表通道不可用,该功能后续会修复。
|
||||||
|
|
||||||
|
另外,OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。
|
||||||
|
</Message>
|
||||||
|
)
|
||||||
|
}
|
||||||
<Table basic compact size='small'>
|
<Table basic compact size='small'>
|
||||||
<Table.Header>
|
<Table.Header>
|
||||||
<Table.Row>
|
<Table.Row>
|
||||||
@@ -343,10 +389,10 @@ const ChannelsTable = () => {
|
|||||||
余额
|
余额
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
<Table.HeaderCell
|
<Table.HeaderCell
|
||||||
style={{ cursor: 'pointer' }}
|
style={{ cursor: 'pointer' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
sortChannel('priority');
|
sortChannel('priority');
|
||||||
}}
|
}}
|
||||||
>
|
>
|
||||||
优先级
|
优先级
|
||||||
</Table.HeaderCell>
|
</Table.HeaderCell>
|
||||||
@@ -390,18 +436,18 @@ const ChannelsTable = () => {
|
|||||||
</Table.Cell>
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
<Popup
|
<Popup
|
||||||
trigger={<Input type="number" defaultValue={channel.priority} onBlur={(event) => {
|
trigger={<Input type='number' defaultValue={channel.priority} onBlur={(event) => {
|
||||||
manageChannel(
|
manageChannel(
|
||||||
channel.id,
|
channel.id,
|
||||||
'priority',
|
'priority',
|
||||||
idx,
|
idx,
|
||||||
event.target.value,
|
event.target.value
|
||||||
);
|
);
|
||||||
}}>
|
}}>
|
||||||
<input style={{maxWidth:'60px'}} />
|
<input style={{ maxWidth: '60px' }} />
|
||||||
</Input>}
|
</Input>}
|
||||||
content='渠道选择优先级,越高越优先'
|
content='渠道选择优先级,越高越优先'
|
||||||
basic
|
basic
|
||||||
/>
|
/>
|
||||||
</Table.Cell>
|
</Table.Cell>
|
||||||
<Table.Cell>
|
<Table.Cell>
|
||||||
@@ -481,6 +527,20 @@ const ChannelsTable = () => {
|
|||||||
</Button>
|
</Button>
|
||||||
<Button size='small' onClick={updateAllChannelsBalance}
|
<Button size='small' onClick={updateAllChannelsBalance}
|
||||||
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
|
loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
|
||||||
|
<Popup
|
||||||
|
trigger={
|
||||||
|
<Button size='small' loading={loading}>
|
||||||
|
删除禁用渠道
|
||||||
|
</Button>
|
||||||
|
}
|
||||||
|
on='click'
|
||||||
|
flowing
|
||||||
|
hoverable
|
||||||
|
>
|
||||||
|
<Button size='small' loading={loading} negative onClick={deleteAllDisabledChannels}>
|
||||||
|
确认删除
|
||||||
|
</Button>
|
||||||
|
</Popup>
|
||||||
<Pagination
|
<Pagination
|
||||||
floated='right'
|
floated='right'
|
||||||
activePage={activePage}
|
activePage={activePage}
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ import React, { useContext, useEffect, useState } from 'react';
|
|||||||
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
|
import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } from 'semantic-ui-react';
|
||||||
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
|
||||||
import { UserContext } from '../context/User';
|
import { UserContext } from '../context/User';
|
||||||
import { API, getLogo, showError, showSuccess } from '../helpers';
|
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
|
||||||
import { getOAuthState, onGitHubOAuthClicked } from './utils';
|
import { onGitHubOAuthClicked } from './utils';
|
||||||
|
|
||||||
const LoginForm = () => {
|
const LoginForm = () => {
|
||||||
const [inputs, setInputs] = useState({
|
const [inputs, setInputs] = useState({
|
||||||
@@ -68,8 +68,14 @@ const LoginForm = () => {
|
|||||||
if (success) {
|
if (success) {
|
||||||
userDispatch({ type: 'login', payload: data });
|
userDispatch({ type: 'login', payload: data });
|
||||||
localStorage.setItem('user', JSON.stringify(data));
|
localStorage.setItem('user', JSON.stringify(data));
|
||||||
navigate('/');
|
if (username === 'root' && password === '123456') {
|
||||||
showSuccess('登录成功!');
|
navigate('/user/edit');
|
||||||
|
showSuccess('登录成功!');
|
||||||
|
showWarning('请立刻修改默认密码!');
|
||||||
|
} else {
|
||||||
|
navigate('/token');
|
||||||
|
showSuccess('登录成功!');
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
showError(message);
|
showError(message);
|
||||||
}
|
}
|
||||||
@@ -126,7 +132,7 @@ const LoginForm = () => {
|
|||||||
circular
|
circular
|
||||||
color='black'
|
color='black'
|
||||||
icon='github'
|
icon='github'
|
||||||
onClick={()=>onGitHubOAuthClicked(status.github_client_id)}
|
onClick={() => onGitHubOAuthClicked(status.github_client_id)}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<></>
|
<></>
|
||||||
|
|||||||
@@ -130,7 +130,13 @@ const RedemptionsTable = () => {
|
|||||||
setLoading(true);
|
setLoading(true);
|
||||||
let sortedRedemptions = [...redemptions];
|
let sortedRedemptions = [...redemptions];
|
||||||
sortedRedemptions.sort((a, b) => {
|
sortedRedemptions.sort((a, b) => {
|
||||||
return ('' + a[key]).localeCompare(b[key]);
|
if (!isNaN(a[key])) {
|
||||||
|
// If the value is numeric, subtract to sort
|
||||||
|
return a[key] - b[key];
|
||||||
|
} else {
|
||||||
|
// If the value is not numeric, sort as strings
|
||||||
|
return ('' + a[key]).localeCompare(b[key]);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
if (sortedRedemptions[0].id === redemptions[0].id) {
|
if (sortedRedemptions[0].id === redemptions[0].id) {
|
||||||
sortedRedemptions.reverse();
|
sortedRedemptions.reverse();
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ const TokensTable = () => {
|
|||||||
let defaultUrl;
|
let defaultUrl;
|
||||||
|
|
||||||
if (chatLink) {
|
if (chatLink) {
|
||||||
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}"}`;
|
defaultUrl = chatLink + `/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
} else {
|
} else {
|
||||||
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
defaultUrl = `https://chat.oneapi.pro/#/?settings={"key":"sk-${key}","url":"${serverAddress}"}`;
|
||||||
}
|
}
|
||||||
@@ -228,7 +228,13 @@ const TokensTable = () => {
|
|||||||
setLoading(true);
|
setLoading(true);
|
||||||
let sortedTokens = [...tokens];
|
let sortedTokens = [...tokens];
|
||||||
sortedTokens.sort((a, b) => {
|
sortedTokens.sort((a, b) => {
|
||||||
return ('' + a[key]).localeCompare(b[key]);
|
if (!isNaN(a[key])) {
|
||||||
|
// If the value is numeric, subtract to sort
|
||||||
|
return a[key] - b[key];
|
||||||
|
} else {
|
||||||
|
// If the value is not numeric, sort as strings
|
||||||
|
return ('' + a[key]).localeCompare(b[key]);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
if (sortedTokens[0].id === tokens[0].id) {
|
if (sortedTokens[0].id === tokens[0].id) {
|
||||||
sortedTokens.reverse();
|
sortedTokens.reverse();
|
||||||
|
|||||||
@@ -133,7 +133,13 @@ const UsersTable = () => {
|
|||||||
setLoading(true);
|
setLoading(true);
|
||||||
let sortedUsers = [...users];
|
let sortedUsers = [...users];
|
||||||
sortedUsers.sort((a, b) => {
|
sortedUsers.sort((a, b) => {
|
||||||
return ('' + a[key]).localeCompare(b[key]);
|
if (!isNaN(a[key])) {
|
||||||
|
// If the value is numeric, subtract to sort
|
||||||
|
return a[key] - b[key];
|
||||||
|
} else {
|
||||||
|
// If the value is not numeric, sort as strings
|
||||||
|
return ('' + a[key]).localeCompare(b[key]);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
if (sortedUsers[0].id === users[0].id) {
|
if (sortedUsers[0].id === users[0].id) {
|
||||||
sortedUsers.reverse();
|
sortedUsers.reverse();
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ export const CHANNEL_OPTIONS = [
|
|||||||
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
{ key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
|
||||||
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
{ key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
|
||||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
{ key: 19, text: '360 智脑', value: 19, color: 'blue' },
|
||||||
|
{ key: 23, text: '腾讯混元', value: 23, color: 'teal' },
|
||||||
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
|
||||||
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
{ key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
|
||||||
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
{ key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
|
||||||
|
|||||||
@@ -186,4 +186,14 @@ export const verifyJSON = (str) => {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export function shouldShowPrompt(id) {
|
||||||
|
let prompt = localStorage.getItem(`prompt-${id}`);
|
||||||
|
return !prompt;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
export function setPromptShown(id) {
|
||||||
|
localStorage.setItem(`prompt-${id}`, 'true');
|
||||||
|
}
|
||||||
@@ -19,6 +19,8 @@ function type2secretPrompt(type) {
|
|||||||
return '按照如下格式输入:APPID|APISecret|APIKey';
|
return '按照如下格式输入:APPID|APISecret|APIKey';
|
||||||
case 22:
|
case 22:
|
||||||
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
||||||
|
case 23:
|
||||||
|
return '按照如下格式输入:AppId|SecretId|SecretKey';
|
||||||
default:
|
default:
|
||||||
return '请输入渠道对应的鉴权密钥';
|
return '请输入渠道对应的鉴权密钥';
|
||||||
}
|
}
|
||||||
@@ -64,19 +66,22 @@ const EditChannel = () => {
|
|||||||
localModels = ['PaLM-2'];
|
localModels = ['PaLM-2'];
|
||||||
break;
|
break;
|
||||||
case 15:
|
case 15:
|
||||||
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
|
localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
|
||||||
break;
|
break;
|
||||||
case 17:
|
case 17:
|
||||||
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
|
localModels = ['qwen-turbo', 'qwen-plus', 'text-embedding-v1'];
|
||||||
break;
|
break;
|
||||||
case 16:
|
case 16:
|
||||||
localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
localModels = ['chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
|
||||||
break;
|
break;
|
||||||
case 18:
|
case 18:
|
||||||
localModels = ['SparkDesk'];
|
localModels = ['SparkDesk'];
|
||||||
break;
|
break;
|
||||||
case 19:
|
case 19:
|
||||||
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
|
localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
|
||||||
|
break;
|
||||||
|
case 23:
|
||||||
|
localModels = ['hunyuan'];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
setInputs((inputs) => ({ ...inputs, models: localModels }));
|
||||||
|
|||||||
@@ -1,19 +1,12 @@
|
|||||||
import React from 'react';
|
import React from 'react';
|
||||||
import { Segment, Header } from 'semantic-ui-react';
|
import { Message } from 'semantic-ui-react';
|
||||||
|
|
||||||
const NotFound = () => (
|
const NotFound = () => (
|
||||||
<>
|
<>
|
||||||
<Header
|
<Message negative>
|
||||||
block
|
<Message.Header>页面不存在</Message.Header>
|
||||||
as="h4"
|
<p>请检查你的浏览器地址是否正确</p>
|
||||||
content="404"
|
</Message>
|
||||||
attached="top"
|
|
||||||
icon="info"
|
|
||||||
className="small-icon"
|
|
||||||
/>
|
|
||||||
<Segment attached="bottom">
|
|
||||||
未找到所请求的页面
|
|
||||||
</Segment>
|
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -102,7 +102,7 @@ const EditUser = () => {
|
|||||||
label='密码'
|
label='密码'
|
||||||
name='password'
|
name='password'
|
||||||
type={'password'}
|
type={'password'}
|
||||||
placeholder={'请输入新的密码'}
|
placeholder={'请输入新的密码,最短 8 位'}
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
value={password}
|
value={password}
|
||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
|
|||||||
Reference in New Issue
Block a user