Compare commits
122 Commits
v0.5.6-alp
...
v0.5.12-al
Author | SHA1 | Date | |
---|---|---|---|
|
eed9f5fdf0 | ||
|
f2c51a494c | ||
|
8a4d6f3327 | ||
|
cf4e33cb12 | ||
|
5d60305570 | ||
|
d062bc60e4 | ||
|
39c1882970 | ||
|
9c42c7dfd9 | ||
|
903aaeded0 | ||
|
bdd4be562d | ||
|
37afb313b5 | ||
|
c9ebcab8b8 | ||
|
86261cc656 | ||
|
8491785c9d | ||
|
e848a3f7fa | ||
|
318adf5985 | ||
|
965d7fc3d2 | ||
|
aa3f605894 | ||
|
7b8eff1f22 | ||
|
e80cd508ba | ||
|
d37f836d53 | ||
|
e0b2d1ae47 | ||
|
797ead686b | ||
|
0d22cf9ead | ||
|
48989d4a0b | ||
|
6227eee5bc | ||
|
cbf8f07747 | ||
|
4a96031ce6 | ||
|
92886093ae | ||
|
0c022f17cb | ||
|
83f95935de | ||
|
aa03c89133 | ||
|
505817ca17 | ||
|
cb5a3df616 | ||
|
7772064d87 | ||
|
c50c609565 | ||
|
498dea2dbb | ||
|
c725cc8842 | ||
|
af8908db54 | ||
|
d8029550f7 | ||
|
f44fbe3fe7 | ||
|
1c8922153d | ||
|
f3c07e1451 | ||
|
40ceb29e54 | ||
|
0699ecd0af | ||
|
ee9e746520 | ||
|
a763681c2e | ||
|
b7fcb319da | ||
|
67c64e71c8 | ||
|
97030e27f8 | ||
|
461f5dab56 | ||
|
af378c59af | ||
|
bc6769826b | ||
|
0fe26cc4bd | ||
|
7d6a169669 | ||
|
66f06e5d6f | ||
|
6acb9537a9 | ||
|
7069c49bdf | ||
|
58dee76bf7 | ||
|
5cf23d8698 | ||
|
366b82128f | ||
|
2a70744dbf | ||
|
4c5feee0b6 | ||
|
9ba5388367 | ||
|
379074f7d0 | ||
|
01f7b0186f | ||
|
a3f80a3392 | ||
|
8f5b83562b | ||
|
b7570d5c77 | ||
|
0e73418cdf | ||
|
9889377f0e | ||
|
b273464e77 | ||
|
b4e43d97fd | ||
|
3347a44023 | ||
|
923e24534b | ||
|
b4d67ca614 | ||
|
d85e356b6e | ||
|
495fc628e4 | ||
|
76f9288c34 | ||
|
915d13fdd4 | ||
|
969f539777 | ||
|
54e5f8ecd2 | ||
|
34d517cfa2 | ||
|
ddcaf95f5f | ||
|
1d15157f7d | ||
|
de7b9710a5 | ||
|
58bb3ab6f6 | ||
|
d306cb5229 | ||
|
6c5307d0c4 | ||
|
7c4505bdfc | ||
|
9d43ec57d8 | ||
|
e5311892d1 | ||
|
bc7c9105f4 | ||
|
3fe76c8af7 | ||
|
c70c614018 | ||
|
0d87de697c | ||
|
aec343dc38 | ||
|
89d458b9cf | ||
|
63fafba112 | ||
|
a398f35968 | ||
|
57aa637c77 | ||
|
3b483639a4 | ||
|
22980b4c44 | ||
|
64cdb7eafb | ||
|
824444244b | ||
|
fbe9985f57 | ||
|
a27a5bcc06 | ||
|
e28d4b1741 | ||
|
f073592d39 | ||
|
fa41ca9805 | ||
|
e338de45b6 | ||
|
114587b46f | ||
|
b4b4acc288 | ||
|
d663de3e3a | ||
|
a85ecace2e | ||
|
fbdea91ea1 | ||
|
8d34b7a77e | ||
|
cbd62011b8 | ||
|
4701897e2e | ||
|
0f6c132a80 | ||
|
3cac45dc85 | ||
|
47c08c72ce |
11
.github/workflows/linux-release.yml
vendored
@@ -7,6 +7,11 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -18,13 +23,13 @@ jobs:
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (theme default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
git describe --tags > VERSION
|
||||
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
|
11
.github/workflows/macos-release.yml
vendored
@@ -7,6 +7,11 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
jobs:
|
||||
release:
|
||||
runs-on: macos-latest
|
||||
@@ -18,13 +23,13 @@ jobs:
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (theme default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
git describe --tags > VERSION
|
||||
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh
|
||||
cd ..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
|
11
.github/workflows/windows-release.yml
vendored
@@ -7,6 +7,11 @@ on:
|
||||
tags:
|
||||
- '*'
|
||||
- '!*-alpha*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
name:
|
||||
description: 'reason'
|
||||
required: false
|
||||
jobs:
|
||||
release:
|
||||
runs-on: windows-latest
|
||||
@@ -21,14 +26,14 @@ jobs:
|
||||
- uses: actions/setup-node@v3
|
||||
with:
|
||||
node-version: 16
|
||||
- name: Build Frontend
|
||||
- name: Build Frontend (theme default)
|
||||
env:
|
||||
CI: ""
|
||||
run: |
|
||||
cd web
|
||||
cd web/default
|
||||
npm install
|
||||
REACT_APP_VERSION=$(git describe --tags) npm run build
|
||||
cd ..
|
||||
cd ../..
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v3
|
||||
with:
|
||||
|
3
.gitignore
vendored
@@ -5,4 +5,5 @@ upload
|
||||
*.db
|
||||
build
|
||||
*.db-journal
|
||||
logs
|
||||
logs
|
||||
data
|
17
Dockerfile
@@ -1,10 +1,15 @@
|
||||
FROM node:16 as builder
|
||||
|
||||
WORKDIR /build
|
||||
COPY web/package.json .
|
||||
RUN npm install
|
||||
COPY ./web .
|
||||
WORKDIR /web
|
||||
COPY ./VERSION .
|
||||
COPY ./web .
|
||||
|
||||
WORKDIR /web/default
|
||||
RUN npm install
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
|
||||
WORKDIR /web/berry
|
||||
RUN npm install
|
||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
|
||||
|
||||
FROM golang AS builder2
|
||||
@@ -17,7 +22,7 @@ WORKDIR /build
|
||||
ADD go.mod go.sum ./
|
||||
RUN go mod download
|
||||
COPY . .
|
||||
COPY --from=builder /build/build ./web/build
|
||||
COPY --from=builder /web/build ./web/build
|
||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
|
||||
|
||||
FROM alpine
|
||||
@@ -30,4 +35,4 @@ RUN apk update \
|
||||
COPY --from=builder2 /build/one-api /
|
||||
EXPOSE 3000
|
||||
WORKDIR /data
|
||||
ENTRYPOINT ["/one-api"]
|
||||
ENTRYPOINT ["/one-api"]
|
@@ -3,7 +3,7 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
|
||||
1. Support for multiple large models:
|
||||
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||
+ [x] [Anthropic Claude Series Models](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
|
||||
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
|
||||
@@ -189,6 +189,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
|
||||
|
||||
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
|
||||
|
||||
[](https://zeabur.com/templates/7Q0KO3)
|
||||
|
||||
1. First, fork the code.
|
||||
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console.
|
||||
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
|
||||
|
@@ -3,7 +3,7 @@
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に
|
||||
1. 複数の大型モデルをサポート:
|
||||
+ [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
|
||||
+ [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
|
||||
+ [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
|
||||
@@ -190,6 +190,8 @@ Please refer to the [environment variables](#environment-variables) section for
|
||||
|
||||
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
|
||||
|
||||
[](https://zeabur.com/templates/7Q0KO3)
|
||||
|
||||
1. まず、コードをフォークする。
|
||||
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
|
||||
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
|
||||
|
90
README.md
@@ -4,7 +4,7 @@
|
||||
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
@@ -51,34 +51,29 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
<a href="https://iamazing.cn/page/reward">赞赏支持</a>
|
||||
</p>
|
||||
|
||||
> **Note**
|
||||
> [!NOTE]
|
||||
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
|
||||
>
|
||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
|
||||
|
||||
> **Warning**
|
||||
> [!WARNING]
|
||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
|
||||
|
||||
> **Warning**
|
||||
> [!WARNING]
|
||||
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`!
|
||||
|
||||
## 功能
|
||||
1. 支持多种大模型:
|
||||
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
|
||||
+ [x] [Anthropic Claude 系列模型](https://anthropic.com)
|
||||
+ [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
|
||||
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
|
||||
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
|
||||
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
|
||||
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
|
||||
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
|
||||
+ [x] [360 智脑](https://ai.360.cn)
|
||||
2. 支持配置镜像以及众多第三方代理服务:
|
||||
+ [x] [OpenAI-SB](https://openai-sb.com)
|
||||
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
|
||||
+ [x] [API2D](https://api2d.com/r/197971)
|
||||
+ [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
|
||||
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`)
|
||||
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
|
||||
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729)
|
||||
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
|
||||
3. 支持通过**负载均衡**的方式访问多个渠道。
|
||||
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
|
||||
5. 支持**多机部署**,[详见此处](#多机部署)。
|
||||
@@ -91,26 +86,34 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
|
||||
12. 支持**用户邀请奖励**。
|
||||
13. 支持以美元为单位显示额度。
|
||||
14. 支持发布公告,设置充值链接,设置新用户初始额度。
|
||||
15. 支持模型映射,重定向用户的请求模型。
|
||||
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功。
|
||||
16. 支持失败自动重试。
|
||||
17. 支持绘图接口。
|
||||
18. 支持丰富的**自定义**设置,
|
||||
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。
|
||||
19. 支持丰富的**自定义**设置,
|
||||
1. 支持自定义系统名称,logo 以及页脚。
|
||||
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
|
||||
19. 支持通过系统访问令牌访问管理 API。
|
||||
20. 支持 Cloudflare Turnstile 用户校验。
|
||||
21. 支持用户管理,支持**多种用户登录注册方式**:
|
||||
20. 支持通过系统访问令牌访问管理 API(bearer token,用以替代 cookie,你可以自行抓包来查看 API 的用法)。
|
||||
21. 支持 Cloudflare Turnstile 用户校验。
|
||||
22. 支持用户管理,支持**多种用户登录注册方式**:
|
||||
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
|
||||
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
|
||||
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
|
||||
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
|
||||
|
||||
## 部署
|
||||
### 基于 Docker 进行部署
|
||||
部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
|
||||
```shell
|
||||
# 使用 SQLite 的部署命令:
|
||||
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
|
||||
# 例如:
|
||||
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
|
||||
```
|
||||
|
||||
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
|
||||
|
||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
|
||||
|
||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
|
||||
|
||||
@@ -152,6 +155,19 @@ sudo service nginx restart
|
||||
|
||||
初始账号用户名为 `root`,密码为 `123456`。
|
||||
|
||||
|
||||
### 基于 Docker Compose 进行部署
|
||||
|
||||
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
|
||||
|
||||
```shell
|
||||
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
|
||||
docker-compose up -d
|
||||
|
||||
# 查看部署状态
|
||||
docker-compose ps
|
||||
```
|
||||
|
||||
### 手动部署
|
||||
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
|
||||
```shell
|
||||
@@ -239,7 +255,9 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
||||
<summary><strong>部署到 Zeabur</strong></summary>
|
||||
<div>
|
||||
|
||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用。
|
||||
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
|
||||
|
||||
[](https://zeabur.com/templates/7Q0KO3)
|
||||
|
||||
1. 首先 fork 一份代码。
|
||||
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。
|
||||
@@ -254,6 +272,17 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><strong>部署到 Render</strong></summary>
|
||||
<div>
|
||||
|
||||
> Render 提供免费额度,绑卡后可以进一步提升额度
|
||||
|
||||
Render 可以直接部署 docker 镜像,不需要 fork 仓库:https://dashboard.render.com
|
||||
|
||||
</div>
|
||||
</details>
|
||||
|
||||
## 配置
|
||||
系统本身开箱即用。
|
||||
|
||||
@@ -281,10 +310,11 @@ OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
|
||||
```mermaid
|
||||
graph LR
|
||||
A(用户)
|
||||
A --->|请求| B(One API)
|
||||
A --->|使用 One API 分发的 key 进行请求| B(One API)
|
||||
B -->|中继请求| C(OpenAI)
|
||||
B -->|中继请求| D(Azure)
|
||||
B -->|中继请求| E(其他下游渠道)
|
||||
B -->|中继请求| E(其他 OpenAI API 格式下游渠道)
|
||||
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
|
||||
```
|
||||
|
||||
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
|
||||
@@ -332,6 +362,13 @@ graph LR
|
||||
13. 请求频率限制:
|
||||
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
|
||||
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
|
||||
14. 编码器缓存设置:
|
||||
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
|
||||
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
|
||||
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
|
||||
16. `SQLITE_BUSY_TIMEOUT`:SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
|
||||
17. `GEMINI_SAFETY_SETTING`:Gemini 的安全设置,默认 `BLOCK_NONE`。
|
||||
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
|
||||
|
||||
### 命令行参数
|
||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
|
||||
@@ -371,6 +408,15 @@ https://openai.justsong.cn
|
||||
+ 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
|
||||
6. 报错:`当前分组负载已饱和,请稍后再试`
|
||||
+ 上游通道 429 了。
|
||||
7. 升级之后我的数据会丢失吗?
|
||||
+ 如果使用 MySQL,不会。
|
||||
+ 如果使用 SQLite,需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
|
||||
8. 升级之前数据库需要做变更吗?
|
||||
+ 一般情况下不需要,系统将在初始化的时候自动调整。
|
||||
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
|
||||
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`?
|
||||
+ 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
|
||||
+ 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
|
||||
|
||||
## 相关项目
|
||||
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统
|
||||
|
@@ -21,12 +21,9 @@ var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
|
||||
var DisplayInCurrencyEnabled = true
|
||||
var DisplayTokenStatEnabled = true
|
||||
|
||||
var UsingSQLite = false
|
||||
|
||||
// Any options with "Secret", "Token" in its key won't be return by GetOptions
|
||||
|
||||
var SessionSecret = uuid.New().String()
|
||||
var SQLitePath = "one-api.db"
|
||||
|
||||
var OptionMap map[string]string
|
||||
var OptionMapRWMutex sync.RWMutex
|
||||
@@ -81,6 +78,7 @@ var QuotaForInviter = 0
|
||||
var QuotaForInvitee = 0
|
||||
var ChannelDisableThreshold = 5.0
|
||||
var AutomaticDisableChannelEnabled = false
|
||||
var AutomaticEnableChannelEnabled = false
|
||||
var QuotaRemindThreshold = 1000
|
||||
var PreConsumedQuota = 500
|
||||
var ApproximateTokenEnabled = false
|
||||
@@ -98,6 +96,16 @@ var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
|
||||
var BatchUpdateEnabled = false
|
||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
|
||||
|
||||
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second
|
||||
|
||||
var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
|
||||
|
||||
var Theme = GetOrDefaultString("THEME", "default")
|
||||
var ValidThemes = map[string]bool{
|
||||
"default": true,
|
||||
"berry": true,
|
||||
}
|
||||
|
||||
const (
|
||||
RequestIdKey = "X-Oneapi-Request-Id"
|
||||
)
|
||||
@@ -156,9 +164,10 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
ChannelStatusUnknown = 0
|
||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
ChannelStatusDisabled = 2 // also don't use 0
|
||||
ChannelStatusUnknown = 0
|
||||
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
|
||||
ChannelStatusManuallyDisabled = 2 // also don't use 0
|
||||
ChannelStatusAutoDisabled = 3
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -185,30 +194,34 @@ const (
|
||||
ChannelTypeOpenRouter = 20
|
||||
ChannelTypeAIProxyLibrary = 21
|
||||
ChannelTypeFastGPT = 22
|
||||
ChannelTypeTencent = 23
|
||||
ChannelTypeGemini = 24
|
||||
)
|
||||
|
||||
var ChannelBaseURLs = []string{
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"", // 0
|
||||
"https://api.openai.com", // 1
|
||||
"https://oa.api2d.net", // 2
|
||||
"", // 3
|
||||
"https://api.closeai-proxy.xyz", // 4
|
||||
"https://api.openai-sb.com", // 5
|
||||
"https://api.openaimax.com", // 6
|
||||
"https://api.ohmygpt.com", // 7
|
||||
"", // 8
|
||||
"https://api.caipacity.com", // 9
|
||||
"https://api.aiproxy.io", // 10
|
||||
"", // 11
|
||||
"https://api.api2gpt.com", // 12
|
||||
"https://api.aigc2d.com", // 13
|
||||
"https://api.anthropic.com", // 14
|
||||
"https://aip.baidubce.com", // 15
|
||||
"https://open.bigmodel.cn", // 16
|
||||
"https://dashscope.aliyuncs.com", // 17
|
||||
"", // 18
|
||||
"https://ai.360.cn", // 19
|
||||
"https://openrouter.ai/api", // 20
|
||||
"https://api.aiproxy.io", // 21
|
||||
"https://fastgpt.run/api/openapi", // 22
|
||||
"https://hunyuan.cloud.tencent.com", //23
|
||||
"", //24
|
||||
}
|
||||
|
7
common/database.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package common
|
||||
|
||||
var UsingSQLite = false
|
||||
var UsingPostgreSQL = false
|
||||
|
||||
var SQLitePath = "one-api.db"
|
||||
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)
|
@@ -1,11 +1,13 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func SendEmail(subject string, receiver string, content string) error {
|
||||
@@ -13,15 +15,32 @@ func SendEmail(subject string, receiver string, content string) error {
|
||||
SMTPFrom = SMTPAccount
|
||||
}
|
||||
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
|
||||
|
||||
// Extract domain from SMTPFrom
|
||||
parts := strings.Split(SMTPFrom, "@")
|
||||
var domain string
|
||||
if len(parts) > 1 {
|
||||
domain = parts[1]
|
||||
}
|
||||
// Generate a unique Message-ID
|
||||
buf := make([]byte, 16)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
|
||||
|
||||
mail := []byte(fmt.Sprintf("To: %s\r\n"+
|
||||
"From: %s<%s>\r\n"+
|
||||
"Subject: %s\r\n"+
|
||||
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
|
||||
"Date: %s\r\n"+
|
||||
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, content))
|
||||
receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content))
|
||||
auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
|
||||
addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
|
||||
to := strings.Split(receiver, ";")
|
||||
var err error
|
||||
|
||||
if SMTPPort == 465 {
|
||||
tlsConfig := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
@@ -16,7 +17,13 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
contentType := c.Request.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/json") {
|
||||
err = json.Unmarshal(requestBody, &v)
|
||||
} else {
|
||||
// skip for now
|
||||
// TODO: someday non json request have variant model, we will need to implementation this
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -24,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
111
common/image/image.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
// Regex to match data URL pattern
|
||||
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
|
||||
|
||||
func IsImageUrl(url string) (bool, error) {
|
||||
resp, err := http.Head(url)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
|
||||
return false, nil
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
img, _, err := image.DecodeConfig(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return img.Width, img.Height, nil
|
||||
}
|
||||
|
||||
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
|
||||
// Check if the URL is a data URL
|
||||
matches := dataURLPattern.FindStringSubmatch(url)
|
||||
if len(matches) == 3 {
|
||||
// URL is a data URL
|
||||
mimeType = "image/" + matches[1]
|
||||
data = matches[2]
|
||||
return
|
||||
}
|
||||
|
||||
isImage, err := IsImageUrl(url)
|
||||
if !isImage {
|
||||
return
|
||||
}
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
buffer := bytes.NewBuffer(nil)
|
||||
_, err = buffer.ReadFrom(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mimeType = resp.Header.Get("Content-Type")
|
||||
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
|
||||
)
|
||||
|
||||
var readerPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &bytes.Reader{}
|
||||
},
|
||||
}
|
||||
|
||||
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
|
||||
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
reader := readerPool.Get().(*bytes.Reader)
|
||||
defer readerPool.Put(reader)
|
||||
reader.Reset(decoded)
|
||||
|
||||
img, _, err := image.DecodeConfig(reader)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return img.Width, img.Height, nil
|
||||
}
|
||||
|
||||
func GetImageSize(image string) (width int, height int, err error) {
|
||||
if strings.HasPrefix(image, "data:image/") {
|
||||
return GetImageSizeFromBase64(image)
|
||||
}
|
||||
return GetImageSizeFromUrl(image)
|
||||
}
|
171
common/image/image_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package image_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
img "one-api/common/image"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
_ "golang.org/x/image/webp"
|
||||
)
|
||||
|
||||
type CountingReader struct {
|
||||
reader io.Reader
|
||||
BytesRead int
|
||||
}
|
||||
|
||||
func (r *CountingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.reader.Read(p)
|
||||
r.BytesRead += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
var (
|
||||
cases = []struct {
|
||||
url string
|
||||
format string
|
||||
width int
|
||||
height int
|
||||
}{
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
|
||||
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
|
||||
}
|
||||
)
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
// Bytes read: varies sometimes
|
||||
// jpeg: 1063892
|
||||
// png: 294462
|
||||
// webp: 99529
|
||||
// gif: 956153
|
||||
// jpeg#01: 32805
|
||||
for _, c := range cases {
|
||||
t.Run("Decode:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
reader := &CountingReader{reader: resp.Body}
|
||||
img, format, err := image.Decode(reader)
|
||||
assert.NoError(t, err)
|
||||
size := img.Bounds().Size()
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, size.X)
|
||||
assert.Equal(t, c.height, size.Y)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
|
||||
// Bytes read:
|
||||
// jpeg: 4096
|
||||
// png: 4096
|
||||
// webp: 4096
|
||||
// gif: 4096
|
||||
// jpeg#01: 4096
|
||||
for _, c := range cases {
|
||||
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
reader := &CountingReader{reader: resp.Body}
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, config.Width)
|
||||
assert.Equal(t, c.height, config.Height)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64(t *testing.T) {
|
||||
// Bytes read:
|
||||
// jpeg: 1063892
|
||||
// png: 294462
|
||||
// webp: 99072
|
||||
// gif: 953856
|
||||
// jpeg#01: 32805
|
||||
for _, c := range cases {
|
||||
t.Run("Decode:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||
reader := &CountingReader{reader: body}
|
||||
img, format, err := image.Decode(reader)
|
||||
assert.NoError(t, err)
|
||||
size := img.Bounds().Size()
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, size.X)
|
||||
assert.Equal(t, c.height, size.Y)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
|
||||
// Bytes read:
|
||||
// jpeg: 1536
|
||||
// png: 768
|
||||
// webp: 768
|
||||
// gif: 1536
|
||||
// jpeg#01: 3840
|
||||
for _, c := range cases {
|
||||
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
|
||||
reader := &CountingReader{reader: body}
|
||||
config, format, err := image.DecodeConfig(reader)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.format, format)
|
||||
assert.Equal(t, c.width, config.Width)
|
||||
assert.Equal(t, c.height, config.Height)
|
||||
t.Logf("Bytes read: %d", reader.BytesRead)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetImageSize(t *testing.T) {
|
||||
for i, c := range cases {
|
||||
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||
width, height, err := img.GetImageSize(c.url)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.width, width)
|
||||
assert.Equal(t, c.height, height)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetImageSizeFromBase64(t *testing.T) {
|
||||
for i, c := range cases {
|
||||
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
|
||||
resp, err := http.Get(c.url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
encoded := base64.StdEncoding.EncodeToString(data)
|
||||
width, height, err := img.GetImageSizeFromBase64(encoded)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.width, width)
|
||||
assert.Equal(t, c.height, height)
|
||||
})
|
||||
}
|
||||
}
|
@@ -36,7 +36,11 @@ func init() {
|
||||
}
|
||||
|
||||
if os.Getenv("SESSION_SECRET") != "" {
|
||||
SessionSecret = os.Getenv("SESSION_SECRET")
|
||||
if os.Getenv("SESSION_SECRET") == "random_string" {
|
||||
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
|
||||
} else {
|
||||
SessionSecret = os.Getenv("SESSION_SECRET")
|
||||
}
|
||||
}
|
||||
if os.Getenv("SQLITE_PATH") != "" {
|
||||
SQLitePath = os.Getenv("SQLITE_PATH")
|
||||
|
@@ -3,8 +3,32 @@ package common
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var DalleSizeRatios = map[string]map[string]float64{
|
||||
"dall-e-2": {
|
||||
"256x256": 1,
|
||||
"512x512": 1.125,
|
||||
"1024x1024": 1.25,
|
||||
},
|
||||
"dall-e-3": {
|
||||
"1024x1024": 1,
|
||||
"1024x1792": 2,
|
||||
"1792x1024": 2,
|
||||
},
|
||||
}
|
||||
|
||||
var DalleGenerationImageAmounts = map[string][2]int{
|
||||
"dall-e-2": {1, 10},
|
||||
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
|
||||
}
|
||||
|
||||
var DalleImagePromptLengthLimitations = map[string]int{
|
||||
"dall-e-2": 1000,
|
||||
"dall-e-3": 4000,
|
||||
}
|
||||
|
||||
// ModelRatio
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Blfmc9dlf
|
||||
@@ -19,12 +43,17 @@ var ModelRatio = map[string]float64{
|
||||
"gpt-4-32k": 30,
|
||||
"gpt-4-32k-0314": 30,
|
||||
"gpt-4-32k-0613": 30,
|
||||
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
|
||||
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-0301": 0.75,
|
||||
"gpt-3.5-turbo-0613": 0.75,
|
||||
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
|
||||
"gpt-3.5-turbo-16k-0613": 1.5,
|
||||
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
|
||||
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
|
||||
"davinci-002": 1, // $0.002 / 1K tokens
|
||||
"babbage-002": 0.2, // $0.0004 / 1K tokens
|
||||
"text-ada-001": 0.2,
|
||||
"text-babbage-001": 0.25,
|
||||
"text-curie-001": 1,
|
||||
@@ -32,7 +61,11 @@ var ModelRatio = map[string]float64{
|
||||
"text-davinci-003": 10,
|
||||
"text-davinci-edit-001": 10,
|
||||
"code-davinci-edit-001": 10,
|
||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
|
||||
"tts-1": 7.5, // $0.015 / 1K characters
|
||||
"tts-1-1106": 7.5,
|
||||
"tts-1-hd": 15, // $0.030 / 1K characters
|
||||
"tts-1-hd-1106": 15,
|
||||
"davinci": 10,
|
||||
"curie": 10,
|
||||
"babbage": 10,
|
||||
@@ -41,25 +74,34 @@ var ModelRatio = map[string]float64{
|
||||
"text-search-ada-doc-001": 10,
|
||||
"text-moderation-stable": 0.1,
|
||||
"text-moderation-latest": 0.1,
|
||||
"dall-e": 8,
|
||||
"dall-e-2": 8, // $0.016 - $0.020 / image
|
||||
"dall-e-3": 20, // $0.040 - $0.120 / image
|
||||
"claude-instant-1": 0.815, // $1.63 / 1M tokens
|
||||
"claude-2": 5.51, // $11.02 / 1M tokens
|
||||
"claude-2.0": 5.51, // $11.02 / 1M tokens
|
||||
"claude-2.1": 5.51, // $11.02 / 1M tokens
|
||||
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
|
||||
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
|
||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||
"PaLM-2": 1,
|
||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
|
||||
"qwen-turbo": 0.8572, // ¥0.012 / 1k tokens
|
||||
"qwen-plus": 10, // ¥0.14 / 1k tokens
|
||||
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
|
||||
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
|
||||
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
|
||||
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
|
||||
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
|
||||
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
|
||||
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
|
||||
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
|
||||
"360GPT_S2_V9.4": 0.8572, // ¥0.012 / 1k tokens
|
||||
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
|
||||
}
|
||||
|
||||
func ModelRatio2JSONString() string {
|
||||
@@ -76,6 +118,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||
}
|
||||
|
||||
func GetModelRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
|
||||
name = strings.TrimSuffix(name, "-internet")
|
||||
}
|
||||
ratio, ok := ModelRatio[name]
|
||||
if !ok {
|
||||
SysError("model ratio not found: " + name)
|
||||
@@ -86,9 +131,24 @@ func GetModelRatio(name string) float64 {
|
||||
|
||||
func GetCompletionRatio(name string) float64 {
|
||||
if strings.HasPrefix(name, "gpt-3.5") {
|
||||
if strings.HasSuffix(name, "1106") {
|
||||
return 2
|
||||
}
|
||||
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
|
||||
// TODO: clear this after 2023-12-11
|
||||
now := time.Now()
|
||||
// https://platform.openai.com/docs/models/continuous-model-upgrades
|
||||
// if after 2023-12-11, use 2
|
||||
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
|
||||
return 2
|
||||
}
|
||||
}
|
||||
return 1.333333
|
||||
}
|
||||
if strings.HasPrefix(name, "gpt-4") {
|
||||
if strings.HasSuffix(name, "preview") {
|
||||
return 3
|
||||
}
|
||||
return 2
|
||||
}
|
||||
if strings.HasPrefix(name, "claude-instant-1") {
|
||||
|
@@ -196,6 +196,21 @@ func GetOrDefault(env string, defaultValue int) int {
|
||||
return num
|
||||
}
|
||||
|
||||
func GetOrDefaultString(env string, defaultValue string) string {
|
||||
if env == "" || os.Getenv(env) == "" {
|
||||
return defaultValue
|
||||
}
|
||||
return os.Getenv(env)
|
||||
}
|
||||
|
||||
func MessageWithRequestId(message string, id string) string {
|
||||
return fmt.Sprintf("%s (request id: %s)", message, id)
|
||||
}
|
||||
|
||||
func String2Int(str string) int {
|
||||
num, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return num
|
||||
}
|
||||
|
@@ -4,6 +4,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
)
|
||||
|
||||
func GetSubscription(c *gin.Context) {
|
||||
@@ -27,12 +28,12 @@ func GetSubscription(c *gin.Context) {
|
||||
expiredTime = 0
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := OpenAIError{
|
||||
Error := openai.Error{
|
||||
Message: err.Error(),
|
||||
Type: "upstream_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
"error": Error,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -69,12 +70,12 @@ func GetUsage(c *gin.Context) {
|
||||
quota, err = model.GetUserUsedQuota(userId)
|
||||
}
|
||||
if err != nil {
|
||||
openAIError := OpenAIError{
|
||||
Error := openai.Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
"error": Error,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/util"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -92,7 +93,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
|
||||
for k := range headers {
|
||||
req.Header.Add(k, headers.Get(k))
|
||||
}
|
||||
res, err := httpClient.Do(req)
|
||||
res, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -5,19 +5,25 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/util"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||
func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
@@ -42,14 +48,14 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
||||
}
|
||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model)
|
||||
requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||
} else {
|
||||
if channel.GetBaseURL() != "" {
|
||||
requestURL = channel.GetBaseURL()
|
||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
||||
requestURL = baseURL
|
||||
}
|
||||
requestURL += "/v1/chat/completions"
|
||||
}
|
||||
|
||||
requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
||||
}
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
@@ -64,28 +70,35 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := httpClient.Do(req)
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
var response openai.SlimTextResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
if response.Error.Message == "" {
|
||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
||||
}
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *ChatRequest {
|
||||
testRequest := &ChatRequest{
|
||||
func buildTestRequest() *openai.ChatRequest {
|
||||
testRequest := &openai.ChatRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
}
|
||||
testMessage := Message{
|
||||
testMessage := openai.Message{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
}
|
||||
@@ -136,20 +149,32 @@ func TestChannel(c *gin.Context) {
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
func notifyRootUser(subject string, content string) {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
// enable & notify
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
@@ -172,20 +197,21 @@ func testAllChannels(notify bool) error {
|
||||
}
|
||||
go func() {
|
||||
for _, channel := range channels {
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
continue
|
||||
}
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if milliseconds > disableThreshold {
|
||||
if isChannelEnabled && milliseconds > disableThreshold {
|
||||
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if shouldDisableChannel(openaiErr, -1) {
|
||||
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
}
|
||||
|
@@ -127,6 +127,23 @@ func DeleteChannel(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func DeleteDisabledChannel(c *gin.Context) {
|
||||
rows, err := model.DeleteDisabledChannel()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": rows,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateChannel(c *gin.Context) {
|
||||
channel := model.Channel{}
|
||||
err := c.ShouldBindJSON(&channel)
|
||||
|
@@ -2,8 +2,8 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"one-api/relay/channel/openai"
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/models/list
|
||||
@@ -55,12 +55,21 @@ func init() {
|
||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
|
||||
openAIModels = []OpenAIModels{
|
||||
{
|
||||
Id: "dall-e",
|
||||
Id: "dall-e-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "dall-e",
|
||||
Root: "dall-e-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "dall-e-3",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "dall-e-3",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
@@ -72,6 +81,42 @@ func init() {
|
||||
Root: "whisper-1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-1106",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-hd",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-hd",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "tts-1-hd-1106",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "tts-1-hd-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo",
|
||||
Object: "model",
|
||||
@@ -117,6 +162,15 @@ func init() {
|
||||
Root: "gpt-3.5-turbo-16k-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-1106",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-3.5-turbo-1106",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-3.5-turbo-instruct",
|
||||
Object: "model",
|
||||
@@ -180,6 +234,24 @@ func init() {
|
||||
Root: "gpt-4-32k-0613",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-1106-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-1106-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gpt-4-vision-preview",
|
||||
Object: "model",
|
||||
Created: 1699593571,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "gpt-4-vision-preview",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-ada-002",
|
||||
Object: "model",
|
||||
@@ -270,11 +342,29 @@ func init() {
|
||||
Root: "code-davinci-edit-001",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "davinci-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "davinci-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "babbage-002",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "openai",
|
||||
Permission: permission,
|
||||
Root: "babbage-002",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-instant-1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anturopic",
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-instant-1",
|
||||
Parent: nil,
|
||||
@@ -283,11 +373,29 @@ func init() {
|
||||
Id: "claude-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anturopic",
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.1",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.1",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "claude-2.0",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "anthropic",
|
||||
Permission: permission,
|
||||
Root: "claude-2.0",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot",
|
||||
Object: "model",
|
||||
@@ -306,6 +414,15 @@ func init() {
|
||||
Root: "ERNIE-Bot-turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "ERNIE-Bot-4",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "baidu",
|
||||
Permission: permission,
|
||||
Root: "ERNIE-Bot-4",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "Embedding-V1",
|
||||
Object: "model",
|
||||
@@ -319,11 +436,38 @@ func init() {
|
||||
Id: "PaLM-2",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google",
|
||||
OwnedBy: "google palm",
|
||||
Permission: permission,
|
||||
Root: "PaLM-2",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google gemini",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "gemini-pro-vision",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "google gemini",
|
||||
Permission: permission,
|
||||
Root: "gemini-pro-vision",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_turbo",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "zhipu",
|
||||
Permission: permission,
|
||||
Root: "chatglm_turbo",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "chatglm_pro",
|
||||
Object: "model",
|
||||
@@ -369,6 +513,24 @@ func init() {
|
||||
Root: "qwen-plus",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "qwen-max-longcontext",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "ali",
|
||||
Permission: permission,
|
||||
Root: "qwen-max-longcontext",
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "text-embedding-v1",
|
||||
Object: "model",
|
||||
@@ -424,12 +586,12 @@ func init() {
|
||||
Parent: nil,
|
||||
},
|
||||
{
|
||||
Id: "360GPT_S2_V9.4",
|
||||
Id: "hunyuan",
|
||||
Object: "model",
|
||||
Created: 1677649963,
|
||||
OwnedBy: "360",
|
||||
OwnedBy: "tencent",
|
||||
Permission: permission,
|
||||
Root: "360GPT_S2_V9.4",
|
||||
Root: "hunyuan",
|
||||
Parent: nil,
|
||||
},
|
||||
}
|
||||
@@ -451,14 +613,14 @@ func RetrieveModel(c *gin.Context) {
|
||||
if model, ok := openAIModelsMap[modelId]; ok {
|
||||
c.JSON(200, model)
|
||||
} else {
|
||||
openAIError := OpenAIError{
|
||||
Error := openai.Error{
|
||||
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
|
||||
Type: "invalid_request_error",
|
||||
Param: "model",
|
||||
Code: "model_not_found",
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"error": openAIError,
|
||||
"error": Error,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -42,11 +42,19 @@ func UpdateOption(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switch option.Key {
|
||||
case "Theme":
|
||||
if !common.ValidThemes[option.Value] {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无效的主题",
|
||||
})
|
||||
return
|
||||
}
|
||||
case "GitHubOAuthEnabled":
|
||||
if option.Value == "true" && common.GitHubClientId == "" {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client ID 以及 GitHub Client Secret!",
|
||||
"message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
@@ -1,329 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
type AliMessage struct {
|
||||
User string `json:"user"`
|
||||
Bot string `json:"bot"`
|
||||
}
|
||||
|
||||
type AliInput struct {
|
||||
Prompt string `json:"prompt"`
|
||||
History []AliMessage `json:"history"`
|
||||
}
|
||||
|
||||
type AliParameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
}
|
||||
|
||||
type AliChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input AliInput `json:"input"`
|
||||
Parameters AliParameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type AliEmbedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type AliEmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []AliEmbedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
type AliError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type AliUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type AliOutput struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type AliChatResponse struct {
|
||||
Output AliOutput `json:"output"`
|
||||
Usage AliUsage `json:"usage"`
|
||||
AliError
|
||||
}
|
||||
|
||||
func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
|
||||
messages := make([]AliMessage, 0, len(request.Messages))
|
||||
prompt := ""
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: "Okay",
|
||||
})
|
||||
continue
|
||||
} else {
|
||||
if i == len(request.Messages)-1 {
|
||||
prompt = message.Content
|
||||
break
|
||||
}
|
||||
messages = append(messages, AliMessage{
|
||||
User: message.Content,
|
||||
Bot: request.Messages[i+1].Content,
|
||||
})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return &AliChatRequest{
|
||||
Model: request.Model,
|
||||
Input: AliInput{
|
||||
Prompt: prompt,
|
||||
History: messages,
|
||||
},
|
||||
//Parameters: AliParameters{ // ChatGPT's parameters are not compatible with Ali's
|
||||
// TopP: request.TopP,
|
||||
// TopK: 50,
|
||||
// //Seed: 0,
|
||||
// //EnableSearch: false,
|
||||
//},
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
|
||||
return &AliEmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var aliResponse AliEmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
|
||||
choice := OpenAITextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Role: "assistant",
|
||||
Content: response.Output.Text,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Usage: Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "ernie-bot",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse AliChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aliHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var aliResponse AliChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
@@ -1,149 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
audioModel := "whisper-1"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
modelRatio := common.GetModelRatio(audioModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[audioModel] != "" {
|
||||
audioModel = modelMap[audioModel]
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
requestBody := c.Request.Body
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var audioResponse AudioResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
quota := countTokenText(audioResponse.Text, audioModel)
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, audioModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &audioResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,182 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
imageModel := "dall-e"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
var imageRequest ImageRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// Prompt validation
|
||||
if imageRequest.Prompt == "" {
|
||||
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Not "256x256", "512x512", or "1024x1024"
|
||||
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
|
||||
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// N should between 1 and 10
|
||||
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
|
||||
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
|
||||
sizeRatio := 1.0
|
||||
// Size
|
||||
if imageRequest.Size == "256x256" {
|
||||
sizeRatio = 1
|
||||
} else if imageRequest.Size == "512x512" {
|
||||
sizeRatio = 1.125
|
||||
} else if imageRequest.Size == "1024x1024" {
|
||||
sizeRatio = 1.25
|
||||
}
|
||||
quota := int(ratio*sizeRatio*1000) * imageRequest.N
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var textResponse ImageResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,178 +0,0 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var stopFinishReason = "stop"
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||
}
|
||||
defaultTokenEncoder = gpt35TokenEncoder
|
||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.ModelRatio {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
}
|
||||
}
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
return tokenEncoder
|
||||
}
|
||||
if ok {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
if common.ApproximateTokenEnabled {
|
||||
return int(float64(len(text)) * 0.38)
|
||||
}
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func countTokenMessages(messages []Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Content)
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
func countTokenInput(input any, model string) int {
|
||||
switch input.(type) {
|
||||
case string:
|
||||
return countTokenText(input.(string), model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range input.([]string) {
|
||||
text += s
|
||||
}
|
||||
return countTokenText(text, model)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func countTokenText(text string, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
||||
|
||||
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
|
||||
openAIError := OpenAIError{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
Code: code,
|
||||
}
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: openAIError,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
|
||||
if !common.AutomaticDisableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func setEventStreamHeaders(c *gin.Context) {
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Transfer-Encoding", "chunked")
|
||||
c.Writer.Header().Set("X-Accel-Buffering", "no")
|
||||
}
|
||||
|
||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
|
||||
openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: OpenAIError{
|
||||
Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var textResponse TextResponse
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
openAIErrorWithStatusCode.OpenAIError = textResponse.Error
|
||||
return
|
||||
}
|
@@ -4,196 +4,53 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/controller"
|
||||
"one-api/relay/util"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
RelayModeCompletions
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeAudio
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type ImageRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n"`
|
||||
Size string `json:"size"`
|
||||
}
|
||||
|
||||
type AudioResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
}
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
OpenAIError
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
Error OpenAIError `json:"error"`
|
||||
}
|
||||
|
||||
type OpenAITextResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type OpenAITextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []OpenAITextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type OpenAIEmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []OpenAIEmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int `json:"created"`
|
||||
Data []struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
Choices []struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
relayMode := constant.RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
relayMode = RelayModeChatCompletions
|
||||
relayMode = constant.RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||
relayMode = RelayModeCompletions
|
||||
relayMode = constant.RelayModeCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
relayMode = constant.RelayModeEmbeddings
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
relayMode = constant.RelayModeEmbeddings
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
relayMode = RelayModeModerations
|
||||
relayMode = constant.RelayModeModerations
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
relayMode = RelayModeImagesGenerations
|
||||
relayMode = constant.RelayModeImagesGenerations
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||
relayMode = RelayModeEdits
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
relayMode = RelayModeAudio
|
||||
relayMode = constant.RelayModeEdits
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
|
||||
relayMode = constant.RelayModeAudioSpeech
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
|
||||
relayMode = constant.RelayModeAudioTranscription
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
relayMode = constant.RelayModeAudioTranslation
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var err *openai.ErrorWithStatusCode
|
||||
switch relayMode {
|
||||
case RelayModeImagesGenerations:
|
||||
err = relayImageHelper(c, relayMode)
|
||||
case RelayModeAudio:
|
||||
err = relayAudioHelper(c, relayMode)
|
||||
case constant.RelayModeImagesGenerations:
|
||||
err = controller.RelayImageHelper(c, relayMode)
|
||||
case constant.RelayModeAudioSpeech:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranslation:
|
||||
fallthrough
|
||||
case constant.RelayModeAudioTranscription:
|
||||
err = controller.RelayAudioHelper(c, relayMode)
|
||||
default:
|
||||
err = relayTextHelper(c, relayMode)
|
||||
err = controller.RelayTextHelper(c, relayMode)
|
||||
}
|
||||
if err != nil {
|
||||
requestId := c.GetString(common.RequestIdKey)
|
||||
@@ -206,17 +63,17 @@ func Relay(c *gin.Context) {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.StatusCode == http.StatusTooManyRequests {
|
||||
err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
err.Error.Message = "当前分组上游负载已饱和,请稍后再试"
|
||||
}
|
||||
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
|
||||
err.Error.Message = common.MessageWithRequestId(err.Error.Message, requestId)
|
||||
c.JSON(err.StatusCode, gin.H{
|
||||
"error": err.OpenAIError,
|
||||
"error": err.Error,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
|
||||
// https://platform.openai.com/docs/guides/error-codes/api-errors
|
||||
if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
|
||||
if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
|
||||
channelId := c.GetInt("channel_id")
|
||||
channelName := c.GetString("channel_name")
|
||||
disableChannel(channelId, channelName, err.Message)
|
||||
@@ -225,7 +82,7 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := OpenAIError{
|
||||
err := openai.Error{
|
||||
Message: "API not implemented",
|
||||
Type: "one_api_error",
|
||||
Param: "",
|
||||
@@ -237,7 +94,7 @@ func RelayNotImplemented(c *gin.Context) {
|
||||
}
|
||||
|
||||
func RelayNotFound(c *gin.Context) {
|
||||
err := OpenAIError{
|
||||
err := openai.Error{
|
||||
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
Param: "",
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -248,6 +249,29 @@ func GetUser(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetUserDashboard(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
now := time.Now()
|
||||
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
|
||||
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
|
||||
|
||||
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "无法获取统计信息",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": dashboards,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GenerateAccessToken(c *gin.Context) {
|
||||
id := c.GetInt("id")
|
||||
user, err := model.GetUserById(id, true)
|
||||
|
@@ -9,21 +9,21 @@ services:
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- ./data:/data
|
||||
- ./data/oneapi:/data
|
||||
- ./logs:/app/logs
|
||||
environment:
|
||||
- SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
|
||||
- REDIS_CONN_STRING=redis://redis
|
||||
- SESSION_SECRET=random_string # 修改为随机字符串
|
||||
- TZ=Asia/Shanghai
|
||||
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行
|
||||
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
|
||||
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- db
|
||||
healthcheck:
|
||||
test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
|
||||
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
@@ -32,3 +32,18 @@ services:
|
||||
image: redis:latest
|
||||
container_name: redis
|
||||
restart: always
|
||||
|
||||
db:
|
||||
image: mysql:8.2.0
|
||||
restart: always
|
||||
container_name: mysql
|
||||
volumes:
|
||||
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
|
||||
ports:
|
||||
- '3306:3306'
|
||||
environment:
|
||||
TZ: Asia/Shanghai # 设置时区
|
||||
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
|
||||
MYSQL_USER: oneapi # 创建专用用户
|
||||
MYSQL_PASSWORD: '123456' # 设置专用用户密码
|
||||
MYSQL_DATABASE: one-api # 自动创建数据库
|
14
go.mod
@@ -15,8 +15,11 @@ require (
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/pkoukk/tiktoken-go v0.1.5
|
||||
golang.org/x/crypto v0.9.0
|
||||
github.com/stretchr/testify v1.8.3
|
||||
golang.org/x/crypto v0.17.0
|
||||
golang.org/x/image v0.14.0
|
||||
gorm.io/driver/mysql v1.4.3
|
||||
gorm.io/driver/postgres v1.5.2
|
||||
gorm.io/driver/sqlite v1.4.3
|
||||
gorm.io/gorm v1.25.0
|
||||
)
|
||||
@@ -25,6 +28,7 @@ require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
@@ -49,13 +53,13 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/net v0.10.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/text v0.9.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.15.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
gorm.io/driver/postgres v1.5.2 // indirect
|
||||
)
|
||||
|
19
go.sum
@@ -150,11 +150,13 @@ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUu
|
||||
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
|
||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
|
||||
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
|
||||
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -162,14 +164,14 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@@ -198,7 +200,6 @@ gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBp
|
||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
|
||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
|
||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
|
||||
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
|
||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
|
||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
|
||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
|
251
i18n/en.json
@@ -86,6 +86,7 @@
|
||||
"该令牌已过期": "The token has expired",
|
||||
"该令牌额度已用尽": "The token quota has been used up",
|
||||
"无效的令牌": "Invalid token",
|
||||
"令牌验证失败": "Token verification failed",
|
||||
"id 或 userId 为空!": "id or userId is empty!",
|
||||
"quota 不能为负数!": "quota cannot be negative!",
|
||||
"令牌额度不足": "Insufficient token quota",
|
||||
@@ -119,6 +120,7 @@
|
||||
" 年 ": " y ",
|
||||
"未测试": "Not tested",
|
||||
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
|
||||
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
|
||||
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
|
||||
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
|
||||
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
|
||||
@@ -139,6 +141,7 @@
|
||||
"启用": "Enable",
|
||||
"编辑": "Edit",
|
||||
"添加新的渠道": "Add a new channel",
|
||||
"测试所有通道": "Test all channels",
|
||||
"测试所有已启用通道": "Test all enabled channels",
|
||||
"更新所有已启用通道余额": "Update the balance of all enabled channels",
|
||||
"刷新": "Refresh",
|
||||
@@ -456,6 +459,7 @@
|
||||
"使用明细(总消耗额度:{renderQuota(stat.quota)})": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
|
||||
"用户名称": "User Name",
|
||||
"令牌名称": "Token Name",
|
||||
"默认令牌": "Default Token",
|
||||
"留空则查询全部用户": "Leave blank to query all users",
|
||||
"留空则查询全部令牌": "Leave blank to query all tokens",
|
||||
"模型名称": "Model Name",
|
||||
@@ -524,5 +528,250 @@
|
||||
"模型版本": "Model version",
|
||||
"请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
|
||||
"点击查看": "click to view",
|
||||
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
|
||||
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!",
|
||||
"处理中...": "Processing...",
|
||||
"绑定成功!": "Binding successful!",
|
||||
"登录成功!": "Login successful!",
|
||||
"操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...",
|
||||
"出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...",
|
||||
"首页": "Home",
|
||||
"渠道": "Channel",
|
||||
"令牌": "API Keys",
|
||||
"兑换": "Redeem",
|
||||
"充值": "Recharge",
|
||||
"用户": "Users",
|
||||
"日志": "Logs",
|
||||
"设置": "Settings",
|
||||
"关于": "About",
|
||||
"聊天": "Chat",
|
||||
"注销成功!": "Logout successful!",
|
||||
"注销": "Log out",
|
||||
"登录": "Log in",
|
||||
"注册": "Sign up",
|
||||
"加载{name}中...": "Loading {name}...",
|
||||
"未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!",
|
||||
"请立刻修改默认密码!": "Please change the default password immediately!",
|
||||
"欢迎回来": "Welcome back",
|
||||
"没有账户?": "No account?",
|
||||
"立刻注册": "Sign up now",
|
||||
"用户名": "Username",
|
||||
"密码": "Password",
|
||||
"正在登录……": "Logging in...",
|
||||
"忘记密码": "Forgot password",
|
||||
"其他方式": "Other methods",
|
||||
"微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)",
|
||||
"验证码": "Verification code",
|
||||
"全部用户": "All users",
|
||||
"当前用户": "Current user",
|
||||
"全部": "All",
|
||||
"消费": "Consumption",
|
||||
"管理": "Management",
|
||||
"系统": "System",
|
||||
"未知": "Unknown",
|
||||
"其他模型": "Other models",
|
||||
"复制成功": "Copy successful",
|
||||
"使用明细": "Usages",
|
||||
"刷新": "Refresh",
|
||||
"收起面板": "Collapse panel",
|
||||
"展开面板": "Expand panel",
|
||||
"显示查询选项": "Show search options",
|
||||
"隐藏查询选项": "Hide search options",
|
||||
"用户名称": "User name",
|
||||
"可选值": "Optional values",
|
||||
"渠道 ID": "Channel ID",
|
||||
"令牌名称": "Key name",
|
||||
"模型名称": "Model name",
|
||||
"起始时间": "Start time",
|
||||
"结束时间": "End time",
|
||||
"查询": "Query",
|
||||
"隐藏条形图": "Hide bar chart",
|
||||
"显示条形图": "Show bar chart",
|
||||
"折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data",
|
||||
"总消耗": "Total consumption",
|
||||
"总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made",
|
||||
"{model.name}: {model.value} 次": "{model.name}: {model.value} times",
|
||||
"总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made",
|
||||
"总消耗额度": "Total consumption limit",
|
||||
"暂无数据": "No data available",
|
||||
"更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!",
|
||||
"复制用户名": "Copy username",
|
||||
"{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}",
|
||||
"共 0 条数据": "A total of 0 pieces of data",
|
||||
"选择明细分类": "Select detail category",
|
||||
"模型倍率": "model rate",
|
||||
"分组倍率": "group rate",
|
||||
"新密码已复制到剪贴板:": "New password has been copied to the clipboard:",
|
||||
"密码重置确认": "Password reset confirmation",
|
||||
"邮箱地址": "Email address",
|
||||
"新密码": "New password",
|
||||
"密码已复制到剪贴板:": "Password has been copied to the clipboard:",
|
||||
"密码重置完成": "Password reset complete",
|
||||
"提交": "Submit",
|
||||
"返回登录": "Return to login",
|
||||
"请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed",
|
||||
"重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!",
|
||||
"密码重置": "Password reset",
|
||||
"重试": "Retry",
|
||||
"组": "Group",
|
||||
"令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard",
|
||||
"邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard",
|
||||
"系统令牌已复制到剪切板": "System token has been copied to the clipboard",
|
||||
"请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
|
||||
"账户已删除!": "Account has been deleted!",
|
||||
"微信账户绑定成功!": "WeChat account binding successful!",
|
||||
"请稍后几秒重试,Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!",
|
||||
"验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!",
|
||||
"邮箱账户绑定成功!": "Email account binding successful!",
|
||||
"个人信息": "Personal information",
|
||||
"编辑个人信息": "Edit personal information",
|
||||
"生成系统访问令牌": "Generate system access token",
|
||||
"复制邀请链接": "Copy invitation link",
|
||||
"删除个人帐户": "Delete personal account",
|
||||
"普通用户": "Regular user",
|
||||
"管理员": "Administrator",
|
||||
"超级管理员": "Super administrator",
|
||||
"显示名称": "Display name",
|
||||
"GitHub 账号": "GitHub account",
|
||||
"微信账号": "WeChat account",
|
||||
"修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.",
|
||||
"可用模型": "Available models",
|
||||
"账号绑定": "Account binding",
|
||||
"绑定微信": "Bind WeChat",
|
||||
"绑定 GitHub": "Bind GitHub",
|
||||
"绑定邮箱": "Bind Email",
|
||||
"绑定": "Bind",
|
||||
"绑定邮箱地址": "Bind email address",
|
||||
"输入邮箱地址": "Enter email address",
|
||||
"重新发送": "Resend",
|
||||
"获取验证码": "Get verification code",
|
||||
"确认绑定": "Confirm binding",
|
||||
"取消": "Cancel",
|
||||
"危险操作": "Dangerous operation",
|
||||
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered",
|
||||
"输入你的账户名": "Enter your account name",
|
||||
"以确认删除": "To confirm deletion",
|
||||
"确认删除": "Confirm deletion",
|
||||
"未使用": "Not used",
|
||||
"已禁用": "Disabled",
|
||||
"已使用": "Used",
|
||||
"未知状态": "Unknown status",
|
||||
"操作成功完成!": "Operation successfully completed!",
|
||||
"搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...",
|
||||
"名称": "Name",
|
||||
"状态": "Status",
|
||||
"额度": "Quota",
|
||||
"创建时间": "Creation time",
|
||||
"兑换时间": "Redemption time",
|
||||
"操作": "Operation",
|
||||
"尚未兑换": "Not yet redeemed",
|
||||
"已复制到剪贴板!": "Copied to clipboard!",
|
||||
"无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.",
|
||||
"复制": "Copy",
|
||||
"删除": "Delete",
|
||||
"禁用": "Disable",
|
||||
"启用": "Enable",
|
||||
"编辑": "Edit",
|
||||
"添加新的兑换码": "Add new redemption code",
|
||||
"密码长度不得小于 8 位!": "Password length must not be less than 8 characters!",
|
||||
"两次输入的密码不一致": "The two passwords entered do not match",
|
||||
"注册成功!": "Registration successful!",
|
||||
"请填写注册邮箱!": "Please fill in the registration email!",
|
||||
"请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds",
|
||||
"验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!",
|
||||
"已有账户?": "Already have an account?",
|
||||
"请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)",
|
||||
"请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)",
|
||||
"请再次输入密码": "Please enter the password again",
|
||||
"请输入邮箱地址": "Please enter an email address",
|
||||
"秒后可重发": "Can be resent after seconds",
|
||||
"请输入邮箱验证码": "Please enter the email verification code",
|
||||
"已过期": "Expired",
|
||||
"已启用": "Enabled",
|
||||
"已耗尽": "Exhausted",
|
||||
"无": "None",
|
||||
"令牌密钥": "API Key",
|
||||
"令牌状态": "Key status",
|
||||
"已用额度": "Used quota",
|
||||
"剩余额度": "Remaining quota",
|
||||
"过期时间": "Expiration time",
|
||||
"你确定要删除这个令牌吗?": "Are you sure you want to delete this key?",
|
||||
"无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.",
|
||||
"无限制": "Unlimited",
|
||||
"永不过期": "Never expires",
|
||||
"使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.",
|
||||
"API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.",
|
||||
"创建令牌": "Create Key",
|
||||
"什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!",
|
||||
"你确定要删除该令牌吗": "Are you sure you want to delete this key",
|
||||
"导出令牌信息": "Export key information",
|
||||
"错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!",
|
||||
"错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!",
|
||||
"错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!",
|
||||
"本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!",
|
||||
"错误:": "Error:",
|
||||
"加载首页内容失败...": "Failed to load homepage content...",
|
||||
"系统状况": "System status",
|
||||
"系统信息": "System information",
|
||||
"系统信息总览": "System information overview",
|
||||
"名称:": "Name:",
|
||||
"版本:": "Version:",
|
||||
"源码:": "Source code:",
|
||||
"启动时间:": "Startup time:",
|
||||
"系统配置": "System configuration",
|
||||
"系统配置总览": "System configuration overview",
|
||||
"邮箱验证:": "Email verification:",
|
||||
"未启用": "Not enabled",
|
||||
"Turnstile 用户校验:": "Turnstile user verification:",
|
||||
"页面不存在": "Page does not exist",
|
||||
"请检查你的浏览器地址是否正确": "Please check if your browser address is correct",
|
||||
"个人设置": "Personal settings",
|
||||
"运营设置": "Operations settings",
|
||||
"系统设置": "System settings",
|
||||
"其他设置": "Other settings",
|
||||
"默认令牌": "Default key",
|
||||
"过期时间必须在当前时间之后!": "Expiration time must be after the current time!",
|
||||
"额度必须大于等于 0!": "Quota must be greater than or equal to 0!",
|
||||
"过期时间格式错误!": "Expiration time format error!",
|
||||
"创建令牌数量必须大于等于 1!": "The number of keys to create must be greater than or equal to 1!",
|
||||
"令牌修改成功": "API Key modification successful",
|
||||
"令牌创建成功": "API Key creation successful",
|
||||
"更新令牌信息": "Update key information",
|
||||
"创建新的令牌": "Create a new key",
|
||||
"请输入名称": "Please enter a name",
|
||||
"请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss,-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited",
|
||||
"无限额度": "Unlimited quota",
|
||||
"注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.",
|
||||
"等于": "Equals",
|
||||
"请输入额度(单位:token)": "Please enter the quota (unit: token)",
|
||||
"创建令牌数量": "Create key quantity",
|
||||
"请输入令牌数量": "Please enter the number of keys",
|
||||
"注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.",
|
||||
"我的令牌": "My keys",
|
||||
"请输入额度兑换码!": "Please enter the redeem code!",
|
||||
"充值成功!": "Recharge successful!",
|
||||
"请求失败": "Request failed",
|
||||
"超级管理员未设置充值链接!": "The super administrator did not set a recharge link!",
|
||||
"充值额度": "Recharge quota",
|
||||
"兑换中...": "Redeeming...",
|
||||
"请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.",
|
||||
"用户信息更新成功!": "User information updated successfully!",
|
||||
"更新用户信息": "Update user information",
|
||||
"请输入新的用户名": "Please enter a new username",
|
||||
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters",
|
||||
"请输入新的显示名称": "Please enter a new display name",
|
||||
"分组": "Group",
|
||||
"请选择分组": "Please select a group",
|
||||
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:",
|
||||
"请输入新的剩余额度": "Please enter a new remaining quota",
|
||||
"已绑定的 GitHub 账户": "Bound GitHub account",
|
||||
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified",
|
||||
"已绑定的微信账户": "Bound WeChat account",
|
||||
"已绑定的邮箱账户": "Bound email account",
|
||||
"新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5",
|
||||
"无法正常连接至服务器!": "Unable to connect to the server normally!",
|
||||
"提示:": "Input:",
|
||||
"补全:": "Output:",
|
||||
"搜索令牌名称": "Search key name",
|
||||
"测试所有渠道": "Test all channels",
|
||||
"更新已启用渠道余额": "Update the balance of enabled channels"
|
||||
}
|
||||
|
13
main.go
@@ -10,20 +10,18 @@ import (
|
||||
"one-api/controller"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/router"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
//go:embed web/build
|
||||
//go:embed web/build/*
|
||||
var buildFS embed.FS
|
||||
|
||||
//go:embed web/build/index.html
|
||||
var indexPage []byte
|
||||
|
||||
func main() {
|
||||
common.SetupLogger()
|
||||
common.SysLog("One API " + common.Version + " started")
|
||||
common.SysLog(fmt.Sprintf("One API %s started", common.Version))
|
||||
if os.Getenv("GIN_MODE") != "debug" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
@@ -50,6 +48,7 @@ func main() {
|
||||
|
||||
// Initialize options
|
||||
model.InitOptionMap()
|
||||
common.SysLog(fmt.Sprintf("using theme %s", common.Theme))
|
||||
if common.RedisEnabled {
|
||||
// for compatibility with old versions
|
||||
common.MemoryCacheEnabled = true
|
||||
@@ -82,7 +81,7 @@ func main() {
|
||||
common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
|
||||
model.InitBatchUpdater()
|
||||
}
|
||||
controller.InitTokenEncoders()
|
||||
openai.InitTokenEncoders()
|
||||
|
||||
// Initialize HTTP server
|
||||
server := gin.New()
|
||||
@@ -95,7 +94,7 @@ func main() {
|
||||
store := cookie.NewStore([]byte(common.SessionSecret))
|
||||
server.Use(sessions.Sessions("session", store))
|
||||
|
||||
router.SetRouter(server, buildFS, indexPage)
|
||||
router.SetRouter(server, buildFS)
|
||||
var port = os.Getenv("PORT")
|
||||
if port == "" {
|
||||
port = strconv.Itoa(*common.Port)
|
||||
|
@@ -106,12 +106,6 @@ func TokenAuth() func(c *gin.Context) {
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
requestURL := c.Request.URL.String()
|
||||
consumeQuota := true
|
||||
if strings.HasPrefix(requestURL, "/v1/models") {
|
||||
consumeQuota = false
|
||||
}
|
||||
c.Set("consume_quota", consumeQuota)
|
||||
if len(parts) > 1 {
|
||||
if model.IsAdmin(token.UserId) {
|
||||
c.Set("channelId", parts[1])
|
||||
|
@@ -25,12 +25,12 @@ func Distribute() func(c *gin.Context) {
|
||||
if ok {
|
||||
id, err := strconv.Atoi(channelId.(string))
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
channel, err = model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 ID")
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
@@ -40,10 +40,7 @@ func Distribute() func(c *gin.Context) {
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
var err error
|
||||
if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
}
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
|
||||
return
|
||||
@@ -60,10 +57,10 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e"
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
@@ -90,8 +87,12 @@ func Distribute() func(c *gin.Context) {
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
|
28
middleware/recover.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func RelayPanicRecover() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
common.SysError(fmt.Sprintf("panic detected: %v", err))
|
||||
common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": gin.H{
|
||||
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
|
||||
"type": "one_api_panic",
|
||||
},
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
@@ -15,10 +15,17 @@ type Ability struct {
|
||||
|
||||
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
|
||||
ability := Ability{}
|
||||
groupCol := "`group`"
|
||||
trueVal := "1"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
trueVal = "true"
|
||||
}
|
||||
|
||||
var err error = nil
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model)
|
||||
channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite {
|
||||
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
|
||||
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
|
||||
if common.UsingSQLite || common.UsingPostgreSQL {
|
||||
err = channelQuery.Order("RANDOM()").First(&ability).Error
|
||||
} else {
|
||||
err = channelQuery.Order("RAND()").First(&ability).Error
|
||||
|
@@ -21,14 +21,18 @@ var (
|
||||
)
|
||||
|
||||
func CacheGetTokenByKey(key string) (*Token, error) {
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
var token Token
|
||||
if !common.RedisEnabled {
|
||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
return &token, err
|
||||
}
|
||||
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
|
||||
if err != nil {
|
||||
err := DB.Where("`key` = ?", key).First(&token).Error
|
||||
err := DB.Where(keyCol+" = ?", key).First(&token).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@@ -38,7 +38,11 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
|
||||
}
|
||||
|
||||
func SearchChannels(keyword string) (channels []*Channel, err error) {
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or `key` = ?", keyword, keyword+"%", keyword).Find(&channels).Error
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
|
||||
return channels, err
|
||||
}
|
||||
|
||||
@@ -53,17 +57,6 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func GetRandomChannel() (*Channel, error) {
|
||||
channel := Channel{}
|
||||
var err error = nil
|
||||
if common.UsingSQLite {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
|
||||
} else {
|
||||
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
|
||||
}
|
||||
return &channel, err
|
||||
}
|
||||
|
||||
func BatchInsertChannels(channels []Channel) error {
|
||||
var err error
|
||||
err = DB.Create(&channels).Error
|
||||
@@ -176,3 +169,13 @@ func updateChannelUsedQuota(id int, quota int) {
|
||||
common.SysError("failed to update channel used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteChannelByStatus(status int64) (int64, error) {
|
||||
result := DB.Where("status = ?", status).Delete(&Channel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
func DeleteDisabledChannel() (int64, error) {
|
||||
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
52
model/log.go
@@ -3,19 +3,20 @@ package model
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"one-api/common"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type Log struct {
|
||||
Id int `json:"id"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
|
||||
Type int `json:"type" gorm:"index"`
|
||||
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"`
|
||||
Type int `json:"type" gorm:"index:idx_created_at_type"`
|
||||
Content string `json:"content"`
|
||||
Username string `json:"username" gorm:"index;default:''"`
|
||||
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"`
|
||||
TokenName string `json:"token_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;default:''"`
|
||||
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"`
|
||||
Quota int `json:"quota" gorm:"default:0"`
|
||||
PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
|
||||
CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
|
||||
@@ -94,7 +95,7 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
|
||||
tx = tx.Where("created_at <= ?", endTimestamp)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel = ?", channel)
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
|
||||
return logs, err
|
||||
@@ -151,7 +152,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
|
||||
tx = tx.Where("model_name = ?", modelName)
|
||||
}
|
||||
if channel != 0 {
|
||||
tx = tx.Where("channel = ?", channel)
|
||||
tx = tx.Where("channel_id = ?", channel)
|
||||
}
|
||||
tx.Where("type = ?", LogTypeConsume).Scan("a)
|
||||
return quota
|
||||
@@ -182,3 +183,40 @@ func DeleteOldLog(targetTimestamp int64) (int64, error) {
|
||||
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
|
||||
return result.RowsAffected, result.Error
|
||||
}
|
||||
|
||||
type LogStatistic struct {
|
||||
Day string `gorm:"column:day"`
|
||||
ModelName string `gorm:"column:model_name"`
|
||||
RequestCount int `gorm:"column:request_count"`
|
||||
Quota int `gorm:"column:quota"`
|
||||
PromptTokens int `gorm:"column:prompt_tokens"`
|
||||
CompletionTokens int `gorm:"column:completion_tokens"`
|
||||
}
|
||||
|
||||
func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) {
|
||||
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
|
||||
|
||||
if common.UsingPostgreSQL {
|
||||
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
|
||||
}
|
||||
|
||||
if common.UsingSQLite {
|
||||
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
|
||||
}
|
||||
|
||||
err = DB.Raw(`
|
||||
SELECT `+groupSelect+`,
|
||||
model_name, count(1) as request_count,
|
||||
sum(quota) as quota,
|
||||
sum(prompt_tokens) as prompt_tokens,
|
||||
sum(completion_tokens) as completion_tokens
|
||||
FROM logs
|
||||
WHERE type=2
|
||||
AND user_id= ?
|
||||
AND created_at BETWEEN ? AND ?
|
||||
GROUP BY day, model_name
|
||||
ORDER BY day, model_name
|
||||
`, userId, start, end).Scan(&LogStatistics).Error
|
||||
|
||||
return LogStatistics, err
|
||||
}
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlite"
|
||||
@@ -15,7 +16,7 @@ var DB *gorm.DB
|
||||
|
||||
func createRootAccountIfNeed() error {
|
||||
var user User
|
||||
//if user.Status != common.UserStatusEnabled {
|
||||
//if user.Status != util.UserStatusEnabled {
|
||||
if err := DB.First(&user).Error; err != nil {
|
||||
common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
|
||||
hashedPassword, err := common.Password2Hash("123456")
|
||||
@@ -42,6 +43,7 @@ func chooseDB() (*gorm.DB, error) {
|
||||
if strings.HasPrefix(dsn, "postgres://") {
|
||||
// Use PostgreSQL
|
||||
common.SysLog("using PostgreSQL as database")
|
||||
common.UsingPostgreSQL = true
|
||||
return gorm.Open(postgres.New(postgres.Config{
|
||||
DSN: dsn,
|
||||
PreferSimpleProtocol: true, // disables implicit prepared statement usage
|
||||
@@ -58,7 +60,8 @@ func chooseDB() (*gorm.DB, error) {
|
||||
// Use SQLite
|
||||
common.SysLog("SQL_DSN not set, using SQLite as database")
|
||||
common.UsingSQLite = true
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
|
||||
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
|
||||
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
|
||||
PrepareStmt: true, // precompile SQL
|
||||
})
|
||||
}
|
||||
@@ -81,6 +84,7 @@ func InitDB() (err error) {
|
||||
if !common.IsMasterNode {
|
||||
return nil
|
||||
}
|
||||
common.SysLog("database migration started")
|
||||
err = db.AutoMigrate(&Channel{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@@ -34,6 +34,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
|
||||
common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
|
||||
common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
|
||||
common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled)
|
||||
common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
|
||||
common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
|
||||
common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
|
||||
@@ -71,6 +72,7 @@ func InitOptionMap() {
|
||||
common.OptionMap["ChatLink"] = common.ChatLink
|
||||
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["Theme"] = common.Theme
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
@@ -147,6 +149,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.EmailDomainRestrictionEnabled = boolValue
|
||||
case "AutomaticDisableChannelEnabled":
|
||||
common.AutomaticDisableChannelEnabled = boolValue
|
||||
case "AutomaticEnableChannelEnabled":
|
||||
common.AutomaticEnableChannelEnabled = boolValue
|
||||
case "ApproximateTokenEnabled":
|
||||
common.ApproximateTokenEnabled = boolValue
|
||||
case "LogConsumeEnabled":
|
||||
@@ -217,6 +221,8 @@ func updateOptionMap(key string, value string) (err error) {
|
||||
common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
|
||||
case "QuotaPerUnit":
|
||||
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
|
||||
case "Theme":
|
||||
common.Theme = value
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
@@ -50,8 +50,13 @@ func Redeem(key string, userId int) (quota int, err error) {
|
||||
}
|
||||
redemption := &Redemption{}
|
||||
|
||||
keyCol := "`key`"
|
||||
if common.UsingPostgreSQL {
|
||||
keyCol = `"key"`
|
||||
}
|
||||
|
||||
err = DB.Transaction(func(tx *gorm.DB) error {
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where("`key` = ?", key).First(redemption).Error
|
||||
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error
|
||||
if err != nil {
|
||||
return errors.New("无效的兑换码")
|
||||
}
|
||||
|
@@ -38,39 +38,43 @@ func ValidateUserToken(key string) (token *Token, err error) {
|
||||
return nil, errors.New("未提供令牌")
|
||||
}
|
||||
token, err = CacheGetTokenByKey(key)
|
||||
if err == nil {
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
} else if token.Status == common.TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
if err != nil {
|
||||
common.SysError("CacheGetTokenByKey failed: " + err.Error())
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("无效的令牌")
|
||||
}
|
||||
if token.Status != common.TokenStatusEnabled {
|
||||
return nil, errors.New("该令牌状态不可用")
|
||||
}
|
||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||
if !common.RedisEnabled {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||
if !common.RedisEnabled {
|
||||
// in this case, we can make sure the token is exhausted
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
}
|
||||
return token, nil
|
||||
return nil, errors.New("令牌验证失败")
|
||||
}
|
||||
return nil, errors.New("无效的令牌")
|
||||
if token.Status == common.TokenStatusExhausted {
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
} else if token.Status == common.TokenStatusExpired {
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if token.Status != common.TokenStatusEnabled {
|
||||
return nil, errors.New("该令牌状态不可用")
|
||||
}
|
||||
if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
|
||||
if !common.RedisEnabled {
|
||||
token.Status = common.TokenStatusExpired
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌已过期")
|
||||
}
|
||||
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
|
||||
if !common.RedisEnabled {
|
||||
// in this case, we can make sure the token is exhausted
|
||||
token.Status = common.TokenStatusExhausted
|
||||
err := token.SelectUpdate()
|
||||
if err != nil {
|
||||
common.SysError("failed to update token status" + err.Error())
|
||||
}
|
||||
}
|
||||
return nil, errors.New("该令牌额度已用尽")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func GetTokenByIds(id int, userId int) (*Token, error) {
|
||||
|
@@ -15,7 +15,7 @@ type User struct {
|
||||
Username string `json:"username" gorm:"unique;index" validate:"max=12"`
|
||||
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
|
||||
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, common
|
||||
Role int `json:"role" gorm:"type:int;default:1"` // admin, util
|
||||
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
|
||||
Email string `json:"email" gorm:"index" validate:"max=50"`
|
||||
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
|
||||
@@ -42,7 +42,11 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
||||
}
|
||||
|
||||
func SearchUsers(keyword string) (users []*User, err error) {
|
||||
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
if !common.UsingPostgreSQL {
|
||||
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
} else {
|
||||
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
|
||||
}
|
||||
return users, err
|
||||
}
|
||||
|
||||
@@ -137,7 +141,15 @@ func (user *User) ValidateAndFill() (err error) {
|
||||
if user.Username == "" || password == "" {
|
||||
return errors.New("用户名或密码为空")
|
||||
}
|
||||
DB.Where(User{Username: user.Username}).First(user)
|
||||
err = DB.Where("username = ?", user.Username).First(user).Error
|
||||
if err != nil {
|
||||
// we must make sure check username firstly
|
||||
// consider this case: a malicious user set his username as other's email
|
||||
err := DB.Where("email = ?", user.Username).First(user).Error
|
||||
if err != nil {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
}
|
||||
}
|
||||
okay := common.ValidatePasswordAndHash(password, user.Password)
|
||||
if !okay || user.Status != common.UserStatusEnabled {
|
||||
return errors.New("用户名或密码错误,或用户已被封禁")
|
||||
@@ -266,7 +278,12 @@ func GetUserEmail(id int) (email string, err error) {
|
||||
}
|
||||
|
||||
func GetUserGroup(id int) (group string, err error) {
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select("`group`").Find(&group).Error
|
||||
groupCol := "`group`"
|
||||
if common.UsingPostgreSQL {
|
||||
groupCol = `"group"`
|
||||
}
|
||||
|
||||
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
|
||||
return group, err
|
||||
}
|
||||
|
||||
@@ -309,7 +326,8 @@ func GetRootUserEmail() (email string) {
|
||||
|
||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
|
||||
if common.BatchUpdateEnabled {
|
||||
addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
|
||||
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
|
||||
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
|
||||
return
|
||||
}
|
||||
updateUserUsedQuotaAndRequestCount(id, quota, 1)
|
||||
@@ -327,6 +345,24 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserUsedQuota(id int, quota int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Updates(
|
||||
map[string]interface{}{
|
||||
"used_quota": gorm.Expr("used_quota + ?", quota),
|
||||
},
|
||||
).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user used quota: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func updateUserRequestCount(id int, count int) {
|
||||
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
|
||||
if err != nil {
|
||||
common.SysError("failed to update user request count: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func GetUsernameById(id int) (username string) {
|
||||
DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
|
||||
return username
|
||||
|
@@ -6,13 +6,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
|
||||
|
||||
const (
|
||||
BatchUpdateTypeUserQuota = iota
|
||||
BatchUpdateTypeTokenQuota
|
||||
BatchUpdateTypeUsedQuotaAndRequestCount
|
||||
BatchUpdateTypeUsedQuota
|
||||
BatchUpdateTypeChannelUsedQuota
|
||||
BatchUpdateTypeRequestCount
|
||||
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
|
||||
)
|
||||
|
||||
var batchUpdateStores []map[int]int
|
||||
@@ -51,7 +51,7 @@ func batchUpdate() {
|
||||
store := batchUpdateStores[i]
|
||||
batchUpdateStores[i] = make(map[int]int)
|
||||
batchUpdateLocks[i].Unlock()
|
||||
|
||||
// TODO: maybe we can combine updates with same key?
|
||||
for key, value := range store {
|
||||
switch i {
|
||||
case BatchUpdateTypeUserQuota:
|
||||
@@ -64,8 +64,10 @@ func batchUpdate() {
|
||||
if err != nil {
|
||||
common.SysError("failed to batch update token quota: " + err.Error())
|
||||
}
|
||||
case BatchUpdateTypeUsedQuotaAndRequestCount:
|
||||
updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
|
||||
case BatchUpdateTypeUsedQuota:
|
||||
updateUserUsedQuota(key, value)
|
||||
case BatchUpdateTypeRequestCount:
|
||||
updateUserRequestCount(key, value)
|
||||
case BatchUpdateTypeChannelUsedQuota:
|
||||
updateChannelUsedQuota(key, value)
|
||||
}
|
||||
|
9
pull_request_template.md
Normal file
@@ -0,0 +1,9 @@
|
||||
[//]: # (请按照以下格式关联 issue)
|
||||
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR,谢谢)
|
||||
[//]: # (项目维护者一般仅在周末处理 PR,因此如若未能及时回复希望能理解)
|
||||
[//]: # (开发者交流群:910657413)
|
||||
[//]: # (请在提交 PR 之前删除上面的注释)
|
||||
|
||||
close #issue_number
|
||||
|
||||
我已确认该 PR 已自测通过,相关截图如下:
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package aiproxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,56 +8,27 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
|
||||
|
||||
type AIProxyLibraryRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
LibraryId string `json:"libraryId"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryError struct {
|
||||
ErrCode int `json:"errCode"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryDocument struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type AIProxyLibraryResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Answer string `json:"answer"`
|
||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||
AIProxyLibraryError
|
||||
}
|
||||
|
||||
type AIProxyLibraryStreamResponse struct {
|
||||
Content string `json:"content"`
|
||||
Finish bool `json:"finish"`
|
||||
Model string `json:"model"`
|
||||
Documents []AIProxyLibraryDocument `json:"documents"`
|
||||
}
|
||||
|
||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
|
||||
func ConvertRequest(request openai.GeneralOpenAIRequest) *LibraryRequest {
|
||||
query := ""
|
||||
if len(request.Messages) != 0 {
|
||||
query = request.Messages[len(request.Messages)-1].Content
|
||||
query = request.Messages[len(request.Messages)-1].StringContent()
|
||||
}
|
||||
return &AIProxyLibraryRequest{
|
||||
return &LibraryRequest{
|
||||
Model: request.Model,
|
||||
Stream: request.Stream,
|
||||
Query: query,
|
||||
}
|
||||
}
|
||||
|
||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
|
||||
if len(documents) == 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -68,52 +39,52 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
|
||||
return content
|
||||
}
|
||||
|
||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
|
||||
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
|
||||
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: content,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
|
||||
choice.FinishReason = &stopFinishReason
|
||||
return &ChatCompletionsStreamResponse{
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
return &openai.ChatCompletionsStreamResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = response.Content
|
||||
return &ChatCompletionsStreamResponse{
|
||||
return &openai.ChatCompletionsStreamResponse{
|
||||
Id: common.GetUUID(),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: response.Model,
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
}
|
||||
|
||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var usage openai.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -143,12 +114,12 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
var documents []AIProxyLibraryDocument
|
||||
common.SetEventStreamHeaders(c)
|
||||
var documents []LibraryDocument
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var AIProxyLibraryResponse AIProxyLibraryStreamResponse
|
||||
var AIProxyLibraryResponse LibraryStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -179,28 +150,28 @@ func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIEr
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var AIProxyLibraryResponse AIProxyLibraryResponse
|
||||
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var AIProxyLibraryResponse LibraryResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if AIProxyLibraryResponse.ErrCode != 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: AIProxyLibraryResponse.Message,
|
||||
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
|
||||
Code: AIProxyLibraryResponse.ErrCode,
|
||||
@@ -211,7 +182,7 @@ func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
||||
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
32
relay/channel/aiproxy/model.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package aiproxy
|
||||
|
||||
type LibraryRequest struct {
|
||||
Model string `json:"model"`
|
||||
Query string `json:"query"`
|
||||
LibraryId string `json:"libraryId"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type LibraryError struct {
|
||||
ErrCode int `json:"errCode"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type LibraryDocument struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
type LibraryResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Answer string `json:"answer"`
|
||||
Documents []LibraryDocument `json:"documents"`
|
||||
LibraryError
|
||||
}
|
||||
|
||||
type LibraryStreamResponse struct {
|
||||
Content string `json:"content"`
|
||||
Finish bool `json:"finish"`
|
||||
Model string `json:"model"`
|
||||
Documents []LibraryDocument `json:"documents"`
|
||||
}
|
253
relay/channel/ali/main.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package ali
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
|
||||
|
||||
const EnableSearchModelSuffix = "-internet"
|
||||
|
||||
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
messages = append(messages, Message{
|
||||
Content: message.StringContent(),
|
||||
Role: strings.ToLower(message.Role),
|
||||
})
|
||||
}
|
||||
enableSearch := false
|
||||
aliModel := request.Model
|
||||
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
|
||||
enableSearch = true
|
||||
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
|
||||
}
|
||||
return &ChatRequest{
|
||||
Model: aliModel,
|
||||
Input: Input{
|
||||
Messages: messages,
|
||||
},
|
||||
Parameters: Parameters{
|
||||
EnableSearch: enableSearch,
|
||||
IncrementalOutput: request.Stream,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Model: "text-embedding-v1",
|
||||
Input: struct {
|
||||
Texts []string `json:"texts"`
|
||||
}{
|
||||
Texts: request.ParseInput(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var aliResponse EmbeddingResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if aliResponse.Code != "" {
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
|
||||
Model: "text-embedding-v1",
|
||||
Usage: openai.Usage{TotalTokens: response.Usage.TotalTokens},
|
||||
}
|
||||
|
||||
for _, item := range response.Output.Embeddings {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: `embedding`,
|
||||
Index: item.TextIndex,
|
||||
Embedding: item.Embedding,
|
||||
})
|
||||
}
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Output.Text,
|
||||
},
|
||||
FinishReason: response.Output.FinishReason,
|
||||
}
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.RequestId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: openai.Usage{
|
||||
PromptTokens: response.Usage.InputTokens,
|
||||
CompletionTokens: response.Usage.OutputTokens,
|
||||
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
|
||||
},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = aliResponse.Output.Text
|
||||
if aliResponse.Output.FinishReason != "null" {
|
||||
finishReason := aliResponse.Output.FinishReason
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: aliResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "qwen",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var usage openai.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
//lastResponseText := ""
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var aliResponse ChatResponse
|
||||
err := json.Unmarshal([]byte(data), &aliResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
if aliResponse.Usage.OutputTokens != 0 {
|
||||
usage.PromptTokens = aliResponse.Usage.InputTokens
|
||||
usage.CompletionTokens = aliResponse.Usage.OutputTokens
|
||||
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
|
||||
}
|
||||
response := streamResponseAli2OpenAI(&aliResponse)
|
||||
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
|
||||
//lastResponseText = aliResponse.Output.Text
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var aliResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &aliResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if aliResponse.Code != "" {
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: aliResponse.Message,
|
||||
Type: aliResponse.Code,
|
||||
Param: aliResponse.RequestId,
|
||||
Code: aliResponse.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseAli2OpenAI(&aliResponse)
|
||||
fullTextResponse.Model = "qwen"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
71
relay/channel/ali/model.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package ali
|
||||
|
||||
type Message struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type Input struct {
|
||||
//Prompt string `json:"prompt"`
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type Parameters struct {
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Seed uint64 `json:"seed,omitempty"`
|
||||
EnableSearch bool `json:"enable_search,omitempty"`
|
||||
IncrementalOutput bool `json:"incremental_output,omitempty"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input Input `json:"input"`
|
||||
Parameters Parameters `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input struct {
|
||||
Texts []string `json:"texts"`
|
||||
} `json:"input"`
|
||||
Parameters *struct {
|
||||
TextType string `json:"text_type,omitempty"`
|
||||
} `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type Embedding struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
TextIndex int `json:"text_index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Output struct {
|
||||
Embeddings []Embedding `json:"embeddings"`
|
||||
} `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
Error
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RequestId string `json:"request_id"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Output Output `json:"output"`
|
||||
Usage Usage `json:"usage"`
|
||||
Error
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,37 +8,10 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ClaudeMetadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type ClaudeRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ClaudeError struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type ClaudeResponse struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error ClaudeError `json:"error"`
|
||||
}
|
||||
|
||||
func stopReasonClaude2OpenAI(reason string) string {
|
||||
switch reason {
|
||||
case "stop_sequence":
|
||||
@@ -50,8 +23,8 @@ func stopReasonClaude2OpenAI(reason string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
claudeRequest := ClaudeRequest{
|
||||
func ConvertRequest(textRequest openai.GeneralOpenAIRequest) *Request {
|
||||
claudeRequest := Request{
|
||||
Model: textRequest.Model,
|
||||
Prompt: "",
|
||||
MaxTokensToSample: textRequest.MaxTokens,
|
||||
@@ -70,7 +43,9 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
} else if message.Role == "assistant" {
|
||||
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
|
||||
} else if message.Role == "system" {
|
||||
prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
|
||||
if prompt == "" {
|
||||
prompt = message.StringContent()
|
||||
}
|
||||
}
|
||||
}
|
||||
prompt += "\n\nAssistant:"
|
||||
@@ -78,40 +53,40 @@ func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
|
||||
return &claudeRequest
|
||||
}
|
||||
|
||||
func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = claudeResponse.Completion
|
||||
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
|
||||
if finishReason != "null" {
|
||||
choice.FinishReason = &finishReason
|
||||
}
|
||||
var response ChatCompletionsStreamResponse
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = claudeResponse.Model
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
|
||||
choice := OpenAITextResponseChoice{
|
||||
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
|
||||
Name: nil,
|
||||
},
|
||||
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
@@ -141,13 +116,13 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
var claudeResponse ClaudeResponse
|
||||
var claudeResponse Response
|
||||
err := json.Unmarshal([]byte(data), &claudeResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -171,28 +146,28 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var claudeResponse ClaudeResponse
|
||||
var claudeResponse Response
|
||||
err = json.Unmarshal(responseBody, &claudeResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if claudeResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: claudeResponse.Error.Message,
|
||||
Type: claudeResponse.Error.Type,
|
||||
Param: "",
|
||||
@@ -202,8 +177,9 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseClaude2OpenAI(&claudeResponse)
|
||||
completionTokens := countTokenText(claudeResponse.Completion, model)
|
||||
usage := Usage{
|
||||
fullTextResponse.Model = model
|
||||
completionTokens := openai.CountTokenText(claudeResponse.Completion, model)
|
||||
usage := openai.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
@@ -211,7 +187,7 @@ func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
29
relay/channel/anthropic/model.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package anthropic
|
||||
|
||||
type Metadata struct {
|
||||
UserId string `json:"user_id"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokensToSample int `json:"max_tokens_to_sample"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
//Metadata `json:"metadata,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Completion string `json:"completion"`
|
||||
StopReason string `json:"stop_reason"`
|
||||
Model string `json:"model"`
|
||||
Error Error `json:"error"`
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/util"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -37,59 +40,15 @@ type BaiduError struct {
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
type BaiduChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduChatStreamResponse struct {
|
||||
BaiduChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type BaiduEmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []BaiduEmbeddingData `json:"data"`
|
||||
Usage Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type BaiduAccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
var baiduTokenStore sync.Map
|
||||
|
||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
func ConvertRequest(request openai.GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
messages := make([]BaiduMessage, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: "assistant",
|
||||
@@ -98,7 +57,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
} else {
|
||||
messages = append(messages, BaiduMessage{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -108,56 +67,56 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
|
||||
choice := OpenAITextResponseChoice{
|
||||
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Result,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Id,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: response.Usage,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = baiduResponse.Result
|
||||
if baiduResponse.IsEnd {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: baiduResponse.Id,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: baiduResponse.Created,
|
||||
Model: "ernie-bot",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
|
||||
return &BaiduEmbeddingRequest{
|
||||
func ConvertEmbeddingRequest(request openai.GeneralOpenAIRequest) *EmbeddingRequest {
|
||||
return &EmbeddingRequest{
|
||||
Input: request.ParseInput(),
|
||||
}
|
||||
}
|
||||
|
||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
|
||||
openAIEmbeddingResponse := OpenAIEmbeddingResponse{
|
||||
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
|
||||
openAIEmbeddingResponse := openai.EmbeddingResponse{
|
||||
Object: "list",
|
||||
Data: make([]OpenAIEmbeddingResponseItem, 0, len(response.Data)),
|
||||
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
|
||||
Model: "baidu-embedding",
|
||||
Usage: response.Usage,
|
||||
}
|
||||
for _, item := range response.Data {
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
|
||||
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
|
||||
Object: item.Object,
|
||||
Index: item.Index,
|
||||
Embedding: item.Embedding,
|
||||
@@ -166,8 +125,8 @@ func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbe
|
||||
return &openAIEmbeddingResponse
|
||||
}
|
||||
|
||||
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage Usage
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var usage openai.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -194,11 +153,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var baiduResponse BaiduChatStreamResponse
|
||||
var baiduResponse ChatStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &baiduResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -224,28 +183,28 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var baiduResponse BaiduChatResponse
|
||||
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var baiduResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
@@ -255,9 +214,10 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
|
||||
fullTextResponse.Model = "ernie-bot"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -265,23 +225,23 @@ func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var baiduResponse BaiduEmbeddingResponse
|
||||
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var baiduResponse EmbeddingResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &baiduResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if baiduResponse.ErrorMsg != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: baiduResponse.ErrorMsg,
|
||||
Type: "baidu_error",
|
||||
Param: "",
|
||||
@@ -293,7 +253,7 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
||||
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
@@ -301,10 +261,10 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func getBaiduAccessToken(apiKey string) (string, error) {
|
||||
func GetAccessToken(apiKey string) (string, error) {
|
||||
if val, ok := baiduTokenStore.Load(apiKey); ok {
|
||||
var accessToken BaiduAccessToken
|
||||
if accessToken, ok = val.(BaiduAccessToken); ok {
|
||||
var accessToken AccessToken
|
||||
if accessToken, ok = val.(AccessToken); ok {
|
||||
// soon this will expire
|
||||
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
|
||||
go func() {
|
||||
@@ -319,12 +279,12 @@ func getBaiduAccessToken(apiKey string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
if accessToken == nil {
|
||||
return "", errors.New("getBaiduAccessToken return a nil token")
|
||||
return "", errors.New("GetAccessToken return a nil token")
|
||||
}
|
||||
return (*accessToken).AccessToken, nil
|
||||
}
|
||||
|
||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
|
||||
parts := strings.Split(apiKey, "|")
|
||||
if len(parts) != 2 {
|
||||
return nil, errors.New("invalid baidu apikey")
|
||||
@@ -336,13 +296,13 @@ func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/json")
|
||||
req.Header.Add("Accept", "application/json")
|
||||
res, err := impatientHTTPClient.Do(req)
|
||||
res, err := util.ImpatientHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
var accessToken BaiduAccessToken
|
||||
var accessToken AccessToken
|
||||
err = json.NewDecoder(res.Body).Decode(&accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
50
relay/channel/baidu/model.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package baidu
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/openai"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Result string `json:"result"`
|
||||
IsTruncated bool `json:"is_truncated"`
|
||||
NeedClearHistory bool `json:"need_clear_history"`
|
||||
Usage openai.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type ChatStreamResponse struct {
|
||||
ChatResponse
|
||||
SentenceId int `json:"sentence_id"`
|
||||
IsEnd bool `json:"is_end"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Input []string `json:"input"`
|
||||
}
|
||||
|
||||
type EmbeddingData struct {
|
||||
Object string `json:"object"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Data []EmbeddingData `json:"data"`
|
||||
Usage openai.Usage `json:"usage"`
|
||||
BaiduError
|
||||
}
|
||||
|
||||
type AccessToken struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
Error string `json:"error,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
ExpiresAt time.Time `json:"-"`
|
||||
}
|
299
relay/channel/google/gemini.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/image"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
|
||||
|
||||
const (
|
||||
GeminiVisionMaxImageNum = 16
|
||||
)
|
||||
|
||||
// Setting safety to the lowest possible values since Gemini is already powerless enough
|
||||
func ConvertGeminiRequest(textRequest openai.GeneralOpenAIRequest) *GeminiChatRequest {
|
||||
geminiRequest := GeminiChatRequest{
|
||||
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
|
||||
SafetySettings: []GeminiChatSafetySettings{
|
||||
{
|
||||
Category: "HARM_CATEGORY_HARASSMENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_HATE_SPEECH",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
{
|
||||
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
Threshold: common.GeminiSafetySetting,
|
||||
},
|
||||
},
|
||||
GenerationConfig: GeminiChatGenerationConfig{
|
||||
Temperature: textRequest.Temperature,
|
||||
TopP: textRequest.TopP,
|
||||
MaxOutputTokens: textRequest.MaxTokens,
|
||||
},
|
||||
}
|
||||
if textRequest.Functions != nil {
|
||||
geminiRequest.Tools = []GeminiChatTools{
|
||||
{
|
||||
FunctionDeclarations: textRequest.Functions,
|
||||
},
|
||||
}
|
||||
}
|
||||
shouldAddDummyModelMessage := false
|
||||
for _, message := range textRequest.Messages {
|
||||
content := GeminiChatContent{
|
||||
Role: message.Role,
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: message.StringContent(),
|
||||
},
|
||||
},
|
||||
}
|
||||
openaiContent := message.ParseContent()
|
||||
var parts []GeminiPart
|
||||
imageNum := 0
|
||||
for _, part := range openaiContent {
|
||||
if part.Type == openai.ContentTypeText {
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: part.Text,
|
||||
})
|
||||
} else if part.Type == openai.ContentTypeImageURL {
|
||||
imageNum += 1
|
||||
if imageNum > GeminiVisionMaxImageNum {
|
||||
continue
|
||||
}
|
||||
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
content.Parts = parts
|
||||
|
||||
// there's no assistant role in gemini and API shall vomit if Role is not user or model
|
||||
if content.Role == "assistant" {
|
||||
content.Role = "model"
|
||||
}
|
||||
// Converting system prompt to prompt from user for the same reason
|
||||
if content.Role == "system" {
|
||||
content.Role = "user"
|
||||
shouldAddDummyModelMessage = true
|
||||
}
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, content)
|
||||
|
||||
// If a system message is the last message, we need to add a dummy model message to make gemini happy
|
||||
if shouldAddDummyModelMessage {
|
||||
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
|
||||
Role: "model",
|
||||
Parts: []GeminiPart{
|
||||
{
|
||||
Text: "Okay",
|
||||
},
|
||||
},
|
||||
})
|
||||
shouldAddDummyModelMessage = false
|
||||
}
|
||||
}
|
||||
|
||||
return &geminiRequest
|
||||
}
|
||||
|
||||
type GeminiChatResponse struct {
|
||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||
}
|
||||
|
||||
func (g *GeminiChatResponse) GetResponseText() string {
|
||||
if g == nil {
|
||||
return ""
|
||||
}
|
||||
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
|
||||
return g.Candidates[0].Content.Parts[0].Text
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type GeminiChatCandidate struct {
|
||||
Content GeminiChatContent `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int64 `json:"index"`
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetyRating struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
}
|
||||
|
||||
type GeminiChatPromptFeedback struct {
|
||||
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
|
||||
}
|
||||
|
||||
func responseGeminiChat2OpenAI(response *GeminiChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: "",
|
||||
},
|
||||
FinishReason: constant.StopFinishReason,
|
||||
}
|
||||
if len(candidate.Content.Parts) > 0 {
|
||||
choice.Message.Content = candidate.Content.Parts[0].Text
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = geminiResponse.GetResponseText()
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "gemini"
|
||||
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
data = strings.TrimSpace(data)
|
||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "\"text\": \"")
|
||||
data = strings.TrimSuffix(data, "\"")
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
// this is used to prevent annoying \ related format bug
|
||||
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
|
||||
type dummyStruct struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
var dummy dummyStruct
|
||||
err := json.Unmarshal([]byte(data), &dummy)
|
||||
responseText += dummy.Content
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = dummy.Content
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "gemini-pro",
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func GeminiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = json.Unmarshal(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
Param: "",
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||
fullTextResponse.Model = model
|
||||
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), model)
|
||||
usage := openai.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
80
relay/channel/google/model.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/openai"
|
||||
)
|
||||
|
||||
type GeminiChatRequest struct {
|
||||
Contents []GeminiChatContent `json:"contents"`
|
||||
SafetySettings []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
|
||||
GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
|
||||
Tools []GeminiChatTools `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatContent struct {
|
||||
Role string `json:"role,omitempty"`
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiChatSafetySettings struct {
|
||||
Category string `json:"category"`
|
||||
Threshold string `json:"threshold"`
|
||||
}
|
||||
|
||||
type GeminiChatTools struct {
|
||||
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []openai.Message `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package google
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -7,47 +7,14 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
|
||||
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
|
||||
|
||||
type PaLMChatMessage struct {
|
||||
Author string `json:"author"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type PaLMFilter struct {
|
||||
Reason string `json:"reason"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type PaLMPrompt struct {
|
||||
Messages []PaLMChatMessage `json:"messages"`
|
||||
}
|
||||
|
||||
type PaLMChatRequest struct {
|
||||
Prompt PaLMPrompt `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK int `json:"topK,omitempty"`
|
||||
}
|
||||
|
||||
type PaLMError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type PaLMChatResponse struct {
|
||||
Candidates []PaLMChatMessage `json:"candidates"`
|
||||
Messages []Message `json:"messages"`
|
||||
Filters []PaLMFilter `json:"filters"`
|
||||
Error PaLMError `json:"error"`
|
||||
}
|
||||
|
||||
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
func ConvertPaLMRequest(textRequest openai.GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
palmRequest := PaLMChatRequest{
|
||||
Prompt: PaLMPrompt{
|
||||
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
|
||||
@@ -59,7 +26,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
}
|
||||
for _, message := range textRequest.Messages {
|
||||
palmMessage := PaLMChatMessage{
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
}
|
||||
if message.Role == "user" {
|
||||
palmMessage.Author = "0"
|
||||
@@ -71,14 +38,14 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
|
||||
return &palmRequest
|
||||
}
|
||||
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
|
||||
func responsePaLM2OpenAI(response *PaLMChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
|
||||
}
|
||||
for i, candidate := range response.Candidates {
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: candidate.Content,
|
||||
},
|
||||
@@ -89,20 +56,20 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
if len(palmResponse.Candidates) > 0 {
|
||||
choice.Delta.Content = palmResponse.Candidates[0].Content
|
||||
}
|
||||
choice.FinishReason = &stopFinishReason
|
||||
var response ChatCompletionsStreamResponse
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
var response openai.ChatCompletionsStreamResponse
|
||||
response.Object = "chat.completion.chunk"
|
||||
response.Model = "palm2"
|
||||
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
|
||||
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
|
||||
return &response
|
||||
}
|
||||
|
||||
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
|
||||
func PaLMStreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||
createdTime := common.GetTimestamp()
|
||||
@@ -143,7 +110,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
|
||||
dataChan <- string(jsonResponse)
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -156,28 +123,28 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func PaLMHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var palmResponse PaLMChatResponse
|
||||
err = json.Unmarshal(responseBody, &palmResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: palmResponse.Error.Message,
|
||||
Type: palmResponse.Error.Status,
|
||||
Param: "",
|
||||
@@ -187,8 +154,9 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
|
||||
usage := Usage{
|
||||
fullTextResponse.Model = model
|
||||
completionTokens := openai.CountTokenText(palmResponse.Candidates[0].Content, model)
|
||||
usage := openai.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
@@ -196,7 +164,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
||||
fullTextResponse.Usage = usage
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
6
relay/channel/openai/constant.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package openai
|
||||
|
||||
const (
|
||||
ContentTypeText = "text"
|
||||
ContentTypeImageURL = "image_url"
|
||||
)
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
|
||||
func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*ErrorWithStatusCode, string) {
|
||||
responseText := ""
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
@@ -41,7 +42,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
case constant.RelayModeChatCompletions:
|
||||
var streamResponse ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
@@ -51,7 +52,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
for _, choice := range streamResponse.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
case RelayModeCompletions:
|
||||
case constant.RelayModeCompletions:
|
||||
var streamResponse CompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &streamResponse)
|
||||
if err != nil {
|
||||
@@ -66,7 +67,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -83,56 +84,55 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var textResponse TextResponse
|
||||
if consumeQuota {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
func Handler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*ErrorWithStatusCode, *Usage) {
|
||||
var textResponse SlimTextResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if textResponse.Error.Type != "" {
|
||||
return &ErrorWithStatusCode{
|
||||
Error: textResponse.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
// Reset response body
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// So the HTTPClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err := io.Copy(c.Writer, resp.Body)
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
|
||||
if textResponse.Usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range textResponse.Choices {
|
||||
completionTokens += countTokenText(choice.Message.Content, model)
|
||||
completionTokens += CountTokenText(choice.Message.StringContent(), model)
|
||||
}
|
||||
textResponse.Usage = Usage{
|
||||
PromptTokens: promptTokens,
|
283
relay/channel/openai/model.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package openai
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content any `json:"content"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
}
|
||||
|
||||
type ImageURL struct {
|
||||
Url string `json:"url,omitempty"`
|
||||
Detail string `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
type TextContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type ImageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type OpenAIMessageContent struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Text string `json:"text"`
|
||||
ImageURL *ImageURL `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
func (m Message) IsStringContent() bool {
|
||||
_, ok := m.Content.(string)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m Message) StringContent() string {
|
||||
content, ok := m.Content.(string)
|
||||
if ok {
|
||||
return content
|
||||
}
|
||||
contentList, ok := m.Content.([]any)
|
||||
if ok {
|
||||
var contentStr string
|
||||
for _, contentItem := range contentList {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if contentMap["type"] == ContentTypeText {
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentStr += subStr
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentStr
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m Message) ParseContent() []OpenAIMessageContent {
|
||||
var contentList []OpenAIMessageContent
|
||||
content, ok := m.Content.(string)
|
||||
if ok {
|
||||
contentList = append(contentList, OpenAIMessageContent{
|
||||
Type: ContentTypeText,
|
||||
Text: content,
|
||||
})
|
||||
return contentList
|
||||
}
|
||||
anyList, ok := m.Content.([]any)
|
||||
if ok {
|
||||
for _, contentItem := range anyList {
|
||||
contentMap, ok := contentItem.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
switch contentMap["type"] {
|
||||
case ContentTypeText:
|
||||
if subStr, ok := contentMap["text"].(string); ok {
|
||||
contentList = append(contentList, OpenAIMessageContent{
|
||||
Type: ContentTypeText,
|
||||
Text: subStr,
|
||||
})
|
||||
}
|
||||
case ContentTypeImageURL:
|
||||
if subObj, ok := contentMap["image_url"].(map[string]any); ok {
|
||||
contentList = append(contentList, OpenAIMessageContent{
|
||||
Type: ContentTypeImageURL,
|
||||
ImageURL: &ImageURL{
|
||||
Url: subObj["url"].(string),
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return contentList
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ResponseFormat struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
}
|
||||
|
||||
type GeneralOpenAIRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Messages []Message `json:"messages,omitempty"`
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions any `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
if r.Input == nil {
|
||||
return nil
|
||||
}
|
||||
var input []string
|
||||
switch r.Input.(type) {
|
||||
case string:
|
||||
input = []string{r.Input.(string)}
|
||||
case []any:
|
||||
input = make([]string, 0, len(r.Input.([]any)))
|
||||
for _, item := range r.Input.([]any) {
|
||||
if str, ok := item.(string); ok {
|
||||
input = append(input, str)
|
||||
}
|
||||
}
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type TextRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Prompt string `json:"prompt"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
//Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperJSONResponse struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
type WhisperVerboseJSONResponse struct {
|
||||
Task string `json:"task,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Duration float64 `json:"duration,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
Segments []Segment `json:"segments,omitempty"`
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
Id int `json:"id"`
|
||||
Seek int `json:"seek"`
|
||||
Start float64 `json:"start"`
|
||||
End float64 `json:"end"`
|
||||
Text string `json:"text"`
|
||||
Tokens []int `json:"tokens"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
AvgLogprob float64 `json:"avg_logprob"`
|
||||
CompressionRatio float64 `json:"compression_ratio"`
|
||||
NoSpeechProb float64 `json:"no_speech_prob"`
|
||||
}
|
||||
|
||||
type TextToSpeechRequest struct {
|
||||
Model string `json:"model" binding:"required"`
|
||||
Input string `json:"input" binding:"required"`
|
||||
Voice string `json:"voice" binding:"required"`
|
||||
Speed float64 `json:"speed"`
|
||||
ResponseFormat string `json:"response_format"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
}
|
||||
|
||||
type ErrorWithStatusCode struct {
|
||||
Error
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type SlimTextResponse struct {
|
||||
Choices []TextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
Error Error `json:"error"`
|
||||
}
|
||||
|
||||
type TextResponseChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
type TextResponse struct {
|
||||
Id string `json:"id"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Choices []TextResponseChoice `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type EmbeddingResponseItem struct {
|
||||
Object string `json:"object"`
|
||||
Index int `json:"index"`
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Object string `json:"object"`
|
||||
Data []EmbeddingResponseItem `json:"data"`
|
||||
Model string `json:"model"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ImageResponse struct {
|
||||
Created int `json:"created"`
|
||||
Data []struct {
|
||||
Url string `json:"url"`
|
||||
}
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponseChoice struct {
|
||||
Delta struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason *string `json:"finish_reason,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionsStreamResponse struct {
|
||||
Id string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
|
||||
}
|
||||
|
||||
type CompletionsStreamResponse struct {
|
||||
Choices []struct {
|
||||
Text string `json:"text"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
205
relay/channel/openai/token.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"one-api/common/image"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
|
||||
var defaultTokenEncoder *tiktoken.Tiktoken
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
|
||||
}
|
||||
defaultTokenEncoder = gpt35TokenEncoder
|
||||
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4")
|
||||
if err != nil {
|
||||
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
|
||||
}
|
||||
for model, _ := range common.ModelRatio {
|
||||
if strings.HasPrefix(model, "gpt-3.5") {
|
||||
tokenEncoderMap[model] = gpt35TokenEncoder
|
||||
} else if strings.HasPrefix(model, "gpt-4") {
|
||||
tokenEncoderMap[model] = gpt4TokenEncoder
|
||||
} else {
|
||||
tokenEncoderMap[model] = nil
|
||||
}
|
||||
}
|
||||
common.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
|
||||
tokenEncoder, ok := tokenEncoderMap[model]
|
||||
if ok && tokenEncoder != nil {
|
||||
return tokenEncoder
|
||||
}
|
||||
if ok {
|
||||
tokenEncoder, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
|
||||
tokenEncoder = defaultTokenEncoder
|
||||
}
|
||||
tokenEncoderMap[model] = tokenEncoder
|
||||
return tokenEncoder
|
||||
}
|
||||
return defaultTokenEncoder
|
||||
}
|
||||
|
||||
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
||||
if common.ApproximateTokenEnabled {
|
||||
return int(float64(len(text)) * 0.38)
|
||||
}
|
||||
return len(tokenEncoder.Encode(text, nil, nil))
|
||||
}
|
||||
|
||||
func CountTokenMessages(messages []Message, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
switch v := message.Content.(type) {
|
||||
case string:
|
||||
tokenNum += getTokenNum(tokenEncoder, v)
|
||||
case []any:
|
||||
for _, it := range v {
|
||||
m := it.(map[string]any)
|
||||
switch m["type"] {
|
||||
case "text":
|
||||
tokenNum += getTokenNum(tokenEncoder, m["text"].(string))
|
||||
case "image_url":
|
||||
imageUrl, ok := m["image_url"].(map[string]any)
|
||||
if ok {
|
||||
url := imageUrl["url"].(string)
|
||||
detail := ""
|
||||
if imageUrl["detail"] != nil {
|
||||
detail = imageUrl["detail"].(string)
|
||||
}
|
||||
imageTokens, err := countImageTokens(url, detail)
|
||||
if err != nil {
|
||||
common.SysError("error counting image tokens: " + err.Error())
|
||||
} else {
|
||||
tokenNum += imageTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum
|
||||
}
|
||||
|
||||
const (
|
||||
lowDetailCost = 85
|
||||
highDetailCostPerTile = 170
|
||||
additionalCost = 85
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/guides/vision/calculating-costs
|
||||
// https://github.com/openai/openai-cookbook/blob/05e3f9be4c7a2ae7ecf029a7c32065b024730ebe/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
func countImageTokens(url string, detail string) (_ int, err error) {
|
||||
var fetchSize = true
|
||||
var width, height int
|
||||
// Reference: https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding
|
||||
// detail == "auto" is undocumented on how it works, it just said the model will use the auto setting which will look at the image input size and decide if it should use the low or high setting.
|
||||
// According to the official guide, "low" disable the high-res model,
|
||||
// and only receive low-res 512px x 512px version of the image, indicating
|
||||
// that image is treated as low-res when size is smaller than 512px x 512px,
|
||||
// then we can assume that image size larger than 512px x 512px is treated
|
||||
// as high-res. Then we have the following logic:
|
||||
// if detail == "" || detail == "auto" {
|
||||
// width, height, err = image.GetImageSize(url)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// fetchSize = false
|
||||
// // not sure if this is correct
|
||||
// if width > 512 || height > 512 {
|
||||
// detail = "high"
|
||||
// } else {
|
||||
// detail = "low"
|
||||
// }
|
||||
// }
|
||||
|
||||
// However, in my test, it seems to be always the same as "high".
|
||||
// The following image, which is 125x50, is still treated as high-res, taken
|
||||
// 255 tokens in the response of non-stream chat completion api.
|
||||
// https://upload.wikimedia.org/wikipedia/commons/1/10/18_Infantry_Division_Messina.jpg
|
||||
if detail == "" || detail == "auto" {
|
||||
// assume by test, not sure if this is correct
|
||||
detail = "high"
|
||||
}
|
||||
switch detail {
|
||||
case "low":
|
||||
return lowDetailCost, nil
|
||||
case "high":
|
||||
if fetchSize {
|
||||
width, height, err = image.GetImageSize(url)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if width > 2048 || height > 2048 { // max(width, height) > 2048
|
||||
ratio := float64(2048) / math.Max(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
if width > 768 && height > 768 { // min(width, height) > 768
|
||||
ratio := float64(768) / math.Min(float64(width), float64(height))
|
||||
width = int(float64(width) * ratio)
|
||||
height = int(float64(height) * ratio)
|
||||
}
|
||||
numSquares := int(math.Ceil(float64(width)/512) * math.Ceil(float64(height)/512))
|
||||
result := numSquares*highDetailCostPerTile + additionalCost
|
||||
return result, nil
|
||||
default:
|
||||
return 0, errors.New("invalid detail option")
|
||||
}
|
||||
}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CountTokenText(v, model)
|
||||
case []string:
|
||||
text := ""
|
||||
for _, s := range v {
|
||||
text += s
|
||||
}
|
||||
return CountTokenText(text, model)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func CountTokenText(text string, model string) int {
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
return getTokenNum(tokenEncoder, text)
|
||||
}
|
13
relay/channel/openai/util.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package openai
|
||||
|
||||
func ErrorWrapper(err error, code string, statusCode int) *ErrorWithStatusCode {
|
||||
Error := Error{
|
||||
Message: err.Error(),
|
||||
Type: "one_api_error",
|
||||
Code: code,
|
||||
}
|
||||
return &ErrorWithStatusCode{
|
||||
Error: Error,
|
||||
StatusCode: statusCode,
|
||||
}
|
||||
}
|
232
relay/channel/tencent/main.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// https://cloud.tencent.com/document/product/1729/97732
|
||||
|
||||
func ConvertRequest(request openai.GeneralOpenAIRequest) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for i := 0; i < len(request.Messages); i++ {
|
||||
message := request.Messages[i]
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
continue
|
||||
}
|
||||
messages = append(messages, Message{
|
||||
Content: message.StringContent(),
|
||||
Role: message.Role,
|
||||
})
|
||||
}
|
||||
stream := 0
|
||||
if request.Stream {
|
||||
stream = 1
|
||||
}
|
||||
return &ChatRequest{
|
||||
Timestamp: common.GetTimestamp(),
|
||||
Expired: common.GetTimestamp() + 24*60*60,
|
||||
QueryID: common.GetUUID(),
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
Stream: stream,
|
||||
Messages: messages,
|
||||
}
|
||||
}
|
||||
|
||||
func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Usage: response.Usage,
|
||||
}
|
||||
if len(response.Choices) > 0 {
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Choices[0].Messages.Content,
|
||||
},
|
||||
FinishReason: response.Choices[0].FinishReason,
|
||||
}
|
||||
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "tencent-hunyuan",
|
||||
}
|
||||
if len(TencentResponse.Choices) > 0 {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
|
||||
if TencentResponse.Choices[0].FinishReason == "stop" {
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response.Choices = append(response.Choices, choice)
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, string) {
|
||||
var responseText string
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 5 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:5] != "data:" {
|
||||
continue
|
||||
}
|
||||
data = data[5:]
|
||||
dataChan <- data
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
var TencentResponse ChatResponse
|
||||
err := json.Unmarshal([]byte(data), &TencentResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
response := streamResponseTencent2OpenAI(&TencentResponse)
|
||||
if len(response.Choices) != 0 {
|
||||
responseText += response.Choices[0].Delta.Content
|
||||
}
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
common.SysError("error marshalling stream response: " + err.Error())
|
||||
return true
|
||||
}
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case <-stopChan:
|
||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||
return false
|
||||
}
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
return nil, responseText
|
||||
}
|
||||
|
||||
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var TencentResponse ChatResponse
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &TencentResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if TencentResponse.Error.Code != 0 {
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: TencentResponse.Error.Message,
|
||||
Code: TencentResponse.Error.Code,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseTencent2OpenAI(&TencentResponse)
|
||||
fullTextResponse.Model = "hunyuan"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
return nil, &fullTextResponse.Usage
|
||||
}
|
||||
|
||||
func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
|
||||
parts := strings.Split(config, "|")
|
||||
if len(parts) != 3 {
|
||||
err = errors.New("invalid tencent config")
|
||||
return
|
||||
}
|
||||
appId, err = strconv.ParseInt(parts[0], 10, 64)
|
||||
secretId = parts[1]
|
||||
secretKey = parts[2]
|
||||
return
|
||||
}
|
||||
|
||||
func GetSign(req ChatRequest, secretKey string) string {
|
||||
params := make([]string, 0)
|
||||
params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
|
||||
params = append(params, "secret_id="+req.SecretId)
|
||||
params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
|
||||
params = append(params, "query_id="+req.QueryID)
|
||||
params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
|
||||
params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
|
||||
params = append(params, "stream="+strconv.Itoa(req.Stream))
|
||||
params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
|
||||
|
||||
var messageStr string
|
||||
for _, msg := range req.Messages {
|
||||
messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
|
||||
}
|
||||
messageStr = strings.TrimSuffix(messageStr, ",")
|
||||
params = append(params, "messages=["+messageStr+"]")
|
||||
|
||||
sort.Sort(sort.StringSlice(params))
|
||||
url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
|
||||
mac := hmac.New(sha1.New, []byte(secretKey))
|
||||
signURL := url
|
||||
mac.Write([]byte(signURL))
|
||||
sign := mac.Sum([]byte(nil))
|
||||
return base64.StdEncoding.EncodeToString(sign)
|
||||
}
|
63
relay/channel/tencent/model.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package tencent
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/openai"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
AppId int64 `json:"app_id"` // 腾讯云账号的 APPID
|
||||
SecretId string `json:"secret_id"` // 官网 SecretId
|
||||
// Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。
|
||||
// 例如1529223702,如果与当前时间相差过大,会引起签名过期错误
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
// Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值,
|
||||
// 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天
|
||||
Expired int64 `json:"expired"`
|
||||
QueryID string `json:"query_id"` //请求 Id,用于问题排查
|
||||
// Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定
|
||||
// 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果
|
||||
// 建议该参数和 top_p 只设置1个,不要同时更改 top_p
|
||||
Temperature float64 `json:"temperature"`
|
||||
// TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强
|
||||
// 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果
|
||||
// 建议该参数和 temperature 只设置1个,不要同时更改
|
||||
TopP float64 `json:"top_p"`
|
||||
// Stream 0:同步,1:流式 (默认,协议:SSE)
|
||||
// 同步请求超时:60s,如果内容较长建议使用流式
|
||||
Stream int `json:"stream"`
|
||||
// Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列
|
||||
// 输入 content 总数最大支持 3000 token。
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type ResponseChoices struct {
|
||||
FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
|
||||
Messages Message `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
Delta Message `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Choices []ResponseChoices `json:"choices,omitempty"` // 结果
|
||||
Created string `json:"created,omitempty"` // unix 时间戳的字符串
|
||||
Id string `json:"id,omitempty"` // 会话 id
|
||||
Usage openai.Usage `json:"usage,omitempty"` // token 数量
|
||||
Error Error `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值
|
||||
Note string `json:"note,omitempty"` // 注释
|
||||
ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@@ -19,82 +21,26 @@ import (
|
||||
// https://console.xfyun.cn/services/cbm
|
||||
// https://www.xfyun.cn/doc/spark/Web.html
|
||||
|
||||
type XunfeiMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type XunfeiChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []XunfeiMessage `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type XunfeiChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []XunfeiChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
|
||||
messages := make([]XunfeiMessage, 0, len(request.Messages))
|
||||
func requestOpenAI2Xunfei(request openai.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, XunfeiMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: "assistant",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, XunfeiMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
xunfeiRequest := XunfeiChatRequest{}
|
||||
xunfeiRequest := ChatRequest{}
|
||||
xunfeiRequest.Header.AppId = xunfeiAppId
|
||||
xunfeiRequest.Parameter.Chat.Domain = domain
|
||||
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
|
||||
@@ -104,49 +50,49 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
|
||||
return &xunfeiRequest
|
||||
}
|
||||
|
||||
func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
|
||||
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
|
||||
if len(response.Payload.Choices.Text) == 0 {
|
||||
response.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
response.Payload.Choices.Text = []ChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
choice := OpenAITextResponseChoice{
|
||||
choice := openai.TextResponseChoice{
|
||||
Index: 0,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Payload.Choices.Text[0].Content,
|
||||
},
|
||||
FinishReason: stopFinishReason,
|
||||
FinishReason: constant.StopFinishReason,
|
||||
}
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: []OpenAITextResponseChoice{choice},
|
||||
Choices: []openai.TextResponseChoice{choice},
|
||||
Usage: response.Payload.Usage.Text,
|
||||
}
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseXunfei2OpenAI(xunfeiResponse *XunfeiChatResponse) *ChatCompletionsStreamResponse {
|
||||
func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []XunfeiChatResponseTextItem{
|
||||
xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
if xunfeiResponse.Payload.Choices.Status == 2 {
|
||||
choice.FinishReason = &stopFinishReason
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
}
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "SparkDesk",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
@@ -177,14 +123,14 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
return callUrl
|
||||
}
|
||||
|
||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func StreamHandler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
setEventStreamHeaders(c)
|
||||
var usage Usage
|
||||
common.SetEventStreamHeaders(c)
|
||||
var usage openai.Usage
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case xunfeiResponse := <-dataChan:
|
||||
@@ -207,19 +153,22 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
func Handler(c *gin.Context, textRequest openai.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret)
|
||||
dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "make xunfei request err", http.StatusInternalServerError), nil
|
||||
}
|
||||
var usage Usage
|
||||
var usage openai.Usage
|
||||
var content string
|
||||
var xunfeiResponse XunfeiChatResponse
|
||||
var xunfeiResponse ChatResponse
|
||||
stop := false
|
||||
for !stop {
|
||||
select {
|
||||
case xunfeiResponse = <-dataChan:
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
continue
|
||||
}
|
||||
content += xunfeiResponse.Payload.Choices.Text[0].Content
|
||||
usage.PromptTokens += xunfeiResponse.Payload.Usage.Text.PromptTokens
|
||||
usage.CompletionTokens += xunfeiResponse.Payload.Usage.Text.CompletionTokens
|
||||
@@ -227,20 +176,26 @@ func xunfeiHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId strin
|
||||
case stop = <-stopChan:
|
||||
}
|
||||
}
|
||||
|
||||
if len(xunfeiResponse.Payload.Choices.Text) == 0 {
|
||||
xunfeiResponse.Payload.Choices.Text = []ChatResponseTextItem{
|
||||
{
|
||||
Content: "",
|
||||
},
|
||||
}
|
||||
}
|
||||
xunfeiResponse.Payload.Choices.Text[0].Content = content
|
||||
|
||||
response := responseXunfei2OpenAI(&xunfeiResponse)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, &usage
|
||||
}
|
||||
|
||||
func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId string) (chan XunfeiChatResponse, chan bool, error) {
|
||||
func xunfeiMakeRequest(textRequest openai.GeneralOpenAIRequest, domain, authUrl, appId string) (chan ChatResponse, chan bool, error) {
|
||||
d := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
@@ -254,7 +209,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
dataChan := make(chan XunfeiChatResponse)
|
||||
dataChan := make(chan ChatResponse)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
@@ -263,7 +218,7 @@ func xunfeiMakeRequest(textRequest GeneralOpenAIRequest, domain, authUrl, appId
|
||||
common.SysError("error reading stream response: " + err.Error())
|
||||
break
|
||||
}
|
||||
var response XunfeiChatResponse
|
||||
var response ChatResponse
|
||||
err = json.Unmarshal(msg, &response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -295,8 +250,8 @@ func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string) (string,
|
||||
common.SysLog("api_version not found, use default: " + apiVersion)
|
||||
}
|
||||
domain := "general"
|
||||
if apiVersion == "v2.1" {
|
||||
domain = "generalv2"
|
||||
if apiVersion != "v1.1" {
|
||||
domain += strings.Split(apiVersion, ".")[0]
|
||||
}
|
||||
authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret)
|
||||
return domain, authUrl
|
61
relay/channel/xunfei/model.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package xunfei
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/openai"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Header struct {
|
||||
AppId string `json:"app_id"`
|
||||
} `json:"header"`
|
||||
Parameter struct {
|
||||
Chat struct {
|
||||
Domain string `json:"domain,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Auditing bool `json:"auditing,omitempty"`
|
||||
} `json:"chat"`
|
||||
} `json:"parameter"`
|
||||
Payload struct {
|
||||
Message struct {
|
||||
Text []Message `json:"text"`
|
||||
} `json:"message"`
|
||||
} `json:"payload"`
|
||||
}
|
||||
|
||||
type ChatResponseTextItem struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Header struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Sid string `json:"sid"`
|
||||
Status int `json:"status"`
|
||||
} `json:"header"`
|
||||
Payload struct {
|
||||
Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []ChatResponseTextItem `json:"text"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
//Text struct {
|
||||
// QuestionTokens string `json:"question_tokens"`
|
||||
// PromptTokens string `json:"prompt_tokens"`
|
||||
// CompletionTokens string `json:"completion_tokens"`
|
||||
// TotalTokens string `json:"total_tokens"`
|
||||
//} `json:"text"`
|
||||
Text openai.Usage `json:"text"`
|
||||
} `json:"usage"`
|
||||
} `json:"payload"`
|
||||
}
|
@@ -1,4 +1,4 @@
|
||||
package controller
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,53 +20,13 @@ import (
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
|
||||
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
|
||||
|
||||
type ZhipuMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ZhipuRequest struct {
|
||||
Prompt []ZhipuMessage `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ZhipuResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []ZhipuMessage `json:"choices"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type ZhipuResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ZhipuResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type ZhipuStreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type zhipuTokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
||||
|
||||
var zhipuTokens sync.Map
|
||||
var expSeconds int64 = 24 * 3600
|
||||
|
||||
func getZhipuToken(apikey string) string {
|
||||
func GetToken(apikey string) string {
|
||||
data, ok := zhipuTokens.Load(apikey)
|
||||
if ok {
|
||||
tokenData := data.(zhipuTokenData)
|
||||
tokenData := data.(tokenData)
|
||||
if time.Now().Before(tokenData.ExpiryTime) {
|
||||
return tokenData.Token
|
||||
}
|
||||
@@ -100,7 +62,7 @@ func getZhipuToken(apikey string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
zhipuTokens.Store(apikey, zhipuTokenData{
|
||||
zhipuTokens.Store(apikey, tokenData{
|
||||
Token: tokenString,
|
||||
ExpiryTime: expiryTime,
|
||||
})
|
||||
@@ -108,26 +70,26 @@ func getZhipuToken(apikey string) string {
|
||||
return tokenString
|
||||
}
|
||||
|
||||
func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
messages := make([]ZhipuMessage, 0, len(request.Messages))
|
||||
func ConvertRequest(request openai.GeneralOpenAIRequest) *Request {
|
||||
messages := make([]Message, 0, len(request.Messages))
|
||||
for _, message := range request.Messages {
|
||||
if message.Role == "system" {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: "system",
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
messages = append(messages, ZhipuMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: "user",
|
||||
Content: "Okay",
|
||||
})
|
||||
} else {
|
||||
messages = append(messages, ZhipuMessage{
|
||||
messages = append(messages, Message{
|
||||
Role: message.Role,
|
||||
Content: message.Content,
|
||||
Content: message.StringContent(),
|
||||
})
|
||||
}
|
||||
}
|
||||
return &ZhipuRequest{
|
||||
return &Request{
|
||||
Prompt: messages,
|
||||
Temperature: request.Temperature,
|
||||
TopP: request.TopP,
|
||||
@@ -135,18 +97,18 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
|
||||
}
|
||||
}
|
||||
|
||||
func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
|
||||
fullTextResponse := OpenAITextResponse{
|
||||
func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
|
||||
fullTextResponse := openai.TextResponse{
|
||||
Id: response.Data.TaskId,
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
|
||||
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)),
|
||||
Usage: response.Data.Usage,
|
||||
}
|
||||
for i, choice := range response.Data.Choices {
|
||||
openaiChoice := OpenAITextResponseChoice{
|
||||
openaiChoice := openai.TextResponseChoice{
|
||||
Index: i,
|
||||
Message: Message{
|
||||
Message: openai.Message{
|
||||
Role: choice.Role,
|
||||
Content: strings.Trim(choice.Content, "\""),
|
||||
},
|
||||
@@ -160,34 +122,34 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
|
||||
return &fullTextResponse
|
||||
}
|
||||
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = zhipuResponse
|
||||
response := ChatCompletionsStreamResponse{
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response
|
||||
}
|
||||
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
|
||||
var choice ChatCompletionsStreamResponseChoice
|
||||
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *openai.Usage) {
|
||||
var choice openai.ChatCompletionsStreamResponseChoice
|
||||
choice.Delta.Content = ""
|
||||
choice.FinishReason = &stopFinishReason
|
||||
response := ChatCompletionsStreamResponse{
|
||||
choice.FinishReason = &constant.StopFinishReason
|
||||
response := openai.ChatCompletionsStreamResponse{
|
||||
Id: zhipuResponse.RequestId,
|
||||
Object: "chat.completion.chunk",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: "chatglm",
|
||||
Choices: []ChatCompletionsStreamResponseChoice{choice},
|
||||
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
|
||||
}
|
||||
return &response, &zhipuResponse.Usage
|
||||
}
|
||||
|
||||
func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var usage *Usage
|
||||
func StreamHandler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var usage *openai.Usage
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
@@ -224,7 +186,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
setEventStreamHeaders(c)
|
||||
common.SetEventStreamHeaders(c)
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
@@ -237,7 +199,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||
return true
|
||||
case data := <-metaChan:
|
||||
var zhipuResponse ZhipuStreamMetaResponse
|
||||
var zhipuResponse StreamMetaResponse
|
||||
err := json.Unmarshal([]byte(data), &zhipuResponse)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
@@ -259,28 +221,28 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
|
||||
})
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
|
||||
var zhipuResponse ZhipuResponse
|
||||
func Handler(c *gin.Context, resp *http.Response) (*openai.ErrorWithStatusCode, *openai.Usage) {
|
||||
var zhipuResponse Response
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &zhipuResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
if !zhipuResponse.Success {
|
||||
return &OpenAIErrorWithStatusCode{
|
||||
OpenAIError: OpenAIError{
|
||||
return &openai.ErrorWithStatusCode{
|
||||
Error: openai.Error{
|
||||
Message: zhipuResponse.Msg,
|
||||
Type: "zhipu_error",
|
||||
Param: "",
|
||||
@@ -290,9 +252,10 @@ func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCo
|
||||
}, nil
|
||||
}
|
||||
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
|
||||
fullTextResponse.Model = "chatglm"
|
||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
46
relay/channel/zhipu/model.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package zhipu
|
||||
|
||||
import (
|
||||
"one-api/relay/channel/openai"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
Prompt []Message `json:"prompt"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
RequestId string `json:"request_id,omitempty"`
|
||||
Incremental bool `json:"incremental,omitempty"`
|
||||
}
|
||||
|
||||
type ResponseData struct {
|
||||
TaskId string `json:"task_id"`
|
||||
RequestId string `json:"request_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
Choices []Message `json:"choices"`
|
||||
openai.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Data ResponseData `json:"data"`
|
||||
}
|
||||
|
||||
type StreamMetaResponse struct {
|
||||
RequestId string `json:"request_id"`
|
||||
TaskId string `json:"task_id"`
|
||||
TaskStatus string `json:"task_status"`
|
||||
openai.Usage `json:"usage"`
|
||||
}
|
||||
|
||||
type tokenData struct {
|
||||
Token string
|
||||
ExpiryTime time.Time
|
||||
}
|
16
relay/constant/main.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package constant
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
RelayModeCompletions
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
)
|
||||
|
||||
var StopFinishReason = "stop"
|
265
relay/controller/audio.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
audioModel := "whisper-1"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
tokenName := c.GetString("token_name")
|
||||
|
||||
var ttsRequest openai.TextToSpeechRequest
|
||||
if relayMode == constant.RelayModeAudioSpeech {
|
||||
// Read JSON
|
||||
err := common.UnmarshalBodyReusable(c, &ttsRequest)
|
||||
// Check if JSON is valid
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
|
||||
}
|
||||
audioModel = ttsRequest.Model
|
||||
// Check if text is too long 4096
|
||||
if len(ttsRequest.Input) > 4096 {
|
||||
return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(audioModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
var quota int
|
||||
var preConsumedQuota int
|
||||
switch relayMode {
|
||||
case constant.RelayModeAudioSpeech:
|
||||
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
|
||||
quota = preConsumedQuota
|
||||
default:
|
||||
preConsumedQuota = int(float64(common.PreConsumedQuota) * ratio)
|
||||
}
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// Check if user quota is enough
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
}
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[audioModel] != "" {
|
||||
audioModel = modelMap[audioModel]
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiVersion := util.GetAPIVersion(c)
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
|
||||
}
|
||||
|
||||
requestBody := &bytes.Buffer{}
|
||||
_, err = io.Copy(requestBody, c.Request.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode == constant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
req.Header.Set("api-key", apiKey)
|
||||
req.ContentLength = c.Request.ContentLength
|
||||
} else {
|
||||
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode != constant.RelayModeAudioSpeech {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var openAIErr openai.SlimTextResponse
|
||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
||||
if openAIErr.Error.Message != "" {
|
||||
return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
var text string
|
||||
switch responseFormat {
|
||||
case "json":
|
||||
text, err = getTextFromJSON(responseBody)
|
||||
case "text":
|
||||
text, err = getTextFromText(responseBody)
|
||||
case "srt":
|
||||
text, err = getTextFromSRT(responseBody)
|
||||
case "verbose_json":
|
||||
text, err = getTextFromVerboseJSON(responseBody)
|
||||
case "vtt":
|
||||
text, err = getTextFromVTT(responseBody)
|
||||
default:
|
||||
return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
||||
}
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||
}
|
||||
quota = openai.CountTokenText(text, audioModel)
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if preConsumedQuota > 0 {
|
||||
// we need to roll back the pre-consumed quota
|
||||
defer func(ctx context.Context) {
|
||||
go func() {
|
||||
// negative means add quota back for token & user
|
||||
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
|
||||
}
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return util.RelayErrorHandler(resp)
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
|
||||
}(c.Request.Context())
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTextFromVTT(body []byte) (string, error) {
|
||||
return getTextFromSRT(body)
|
||||
}
|
||||
|
||||
func getTextFromVerboseJSON(body []byte) (string, error) {
|
||||
var whisperResponse openai.WhisperVerboseJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
||||
|
||||
func getTextFromSRT(body []byte) (string, error) {
|
||||
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
||||
var builder strings.Builder
|
||||
var textLine bool
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if textLine {
|
||||
builder.WriteString(line)
|
||||
textLine = false
|
||||
continue
|
||||
} else if strings.Contains(line, "-->") {
|
||||
textLine = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func getTextFromText(body []byte) (string, error) {
|
||||
return strings.TrimSuffix(string(body), "\n"), nil
|
||||
}
|
||||
|
||||
func getTextFromJSON(body []byte) (string, error) {
|
||||
var whisperResponse openai.WhisperJSONResponse
|
||||
if err := json.Unmarshal(body, &whisperResponse); err != nil {
|
||||
return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
|
||||
}
|
||||
return whisperResponse.Text, nil
|
||||
}
|
224
relay/controller/image.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/util"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
||||
return false
|
||||
}
|
||||
min := common.DalleGenerationImageAmounts[element][0]
|
||||
max := common.DalleGenerationImageAmounts[element][1]
|
||||
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
imageModel := "dall-e-2"
|
||||
imageSize := "1024x1024"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
userId := c.GetInt("id")
|
||||
group := c.GetString("group")
|
||||
|
||||
var imageRequest openai.ImageRequest
|
||||
err := common.UnmarshalBodyReusable(c, &imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if imageRequest.N == 0 {
|
||||
imageRequest.N = 1
|
||||
}
|
||||
|
||||
// Size validation
|
||||
if imageRequest.Size != "" {
|
||||
imageSize = imageRequest.Size
|
||||
}
|
||||
|
||||
// Model validation
|
||||
if imageRequest.Model != "" {
|
||||
imageModel = imageRequest.Model
|
||||
}
|
||||
|
||||
imageCostRatio, hasValidSize := common.DalleSizeRatios[imageModel][imageSize]
|
||||
|
||||
// Check if model is supported
|
||||
if hasValidSize {
|
||||
if imageRequest.Quality == "hd" && imageModel == "dall-e-3" {
|
||||
if imageSize == "1024x1024" {
|
||||
imageCostRatio *= 2
|
||||
} else {
|
||||
imageCostRatio *= 1.5
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Prompt validation
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Check prompt length
|
||||
if len(imageRequest.Prompt) > common.DalleImagePromptLengthLimitations[imageModel] {
|
||||
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// Number of generated images validation
|
||||
if isWithinRange(imageModel, imageRequest.N) == false {
|
||||
// channel not azure
|
||||
if channelType != common.ChannelTypeAzure {
|
||||
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
|
||||
apiVersion := util.GetAPIVersion(c)
|
||||
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageModel, apiVersion)
|
||||
}
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body
|
||||
jsonStr, err := json.Marshal(imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
|
||||
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
|
||||
|
||||
if userQuota-quota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
token := c.Request.Header.Get("Authorization")
|
||||
if channelType == common.ChannelTypeAzure { // Azure authentication
|
||||
token = strings.TrimPrefix(token, "Bearer ")
|
||||
req.Header.Set("api-key", token)
|
||||
} else {
|
||||
req.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
var textResponse openai.ImageResponse
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return
|
||||
}
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}(c.Request.Context())
|
||||
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = json.Unmarshal(responseBody, &textResponse)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -8,11 +8,22 @@ import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/aiproxy"
|
||||
"one-api/relay/channel/ali"
|
||||
"one-api/relay/channel/anthropic"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/google"
|
||||
"one-api/relay/channel/openai"
|
||||
"one-api/relay/channel/tencent"
|
||||
"one-api/relay/channel/xunfei"
|
||||
"one-api/relay/channel/zhipu"
|
||||
"one-api/relay/constant"
|
||||
"one-api/relay/util"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,59 +35,51 @@ const (
|
||||
APITypeAli
|
||||
APITypeXunfei
|
||||
APITypeAIProxyLibrary
|
||||
APITypeTencent
|
||||
APITypeGemini
|
||||
)
|
||||
|
||||
var httpClient *http.Client
|
||||
var impatientHTTPClient *http.Client
|
||||
|
||||
func init() {
|
||||
httpClient = &http.Client{}
|
||||
impatientHTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
func RelayTextHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
channelType := c.GetInt("channel")
|
||||
channelId := c.GetInt("channel_id")
|
||||
tokenId := c.GetInt("token_id")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
var textRequest GeneralOpenAIRequest
|
||||
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
var textRequest openai.GeneralOpenAIRequest
|
||||
err := common.UnmarshalBodyReusable(c, &textRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == RelayModeModerations && textRequest.Model == "" {
|
||||
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
|
||||
return openai.ErrorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest)
|
||||
}
|
||||
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
|
||||
textRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
|
||||
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
|
||||
textRequest.Model = c.Param("model")
|
||||
}
|
||||
// request validation
|
||||
if textRequest.Model == "" {
|
||||
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
switch relayMode {
|
||||
case RelayModeCompletions:
|
||||
case constant.RelayModeCompletions:
|
||||
if textRequest.Prompt == "" {
|
||||
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeChatCompletions:
|
||||
case constant.RelayModeChatCompletions:
|
||||
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
|
||||
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeEmbeddings:
|
||||
case RelayModeModerations:
|
||||
case constant.RelayModeEmbeddings:
|
||||
case constant.RelayModeModerations:
|
||||
if textRequest.Input == "" {
|
||||
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
case RelayModeEdits:
|
||||
case constant.RelayModeEdits:
|
||||
if textRequest.Instruction == "" {
|
||||
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
// map model name
|
||||
@@ -86,7 +89,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if modelMap[textRequest.Model] != "" {
|
||||
textRequest.Model = modelMap[textRequest.Model]
|
||||
@@ -109,22 +112,22 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiType = APITypeXunfei
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
apiType = APITypeAIProxyLibrary
|
||||
case common.ChannelTypeTencent:
|
||||
apiType = APITypeTencent
|
||||
case common.ChannelTypeGemini:
|
||||
apiType = APITypeGemini
|
||||
}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if channelType == common.ChannelTypeAzure {
|
||||
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
apiVersion := util.GetAPIVersion(c)
|
||||
requestURL := strings.Split(requestURL, "?")[0]
|
||||
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
|
||||
baseURL = c.GetString("base_url")
|
||||
@@ -135,7 +138,9 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
model_ = strings.TrimSuffix(model_, "-0301")
|
||||
model_ = strings.TrimSuffix(model_, "-0314")
|
||||
model_ = strings.TrimSuffix(model_, "-0613")
|
||||
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
|
||||
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = util.GetFullRequestURL(baseURL, requestURL, channelType)
|
||||
}
|
||||
case APITypeClaude:
|
||||
fullRequestURL = "https://api.anthropic.com/v1/complete"
|
||||
@@ -148,6 +153,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
|
||||
case "ERNIE-Bot-turbo":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
|
||||
case "ERNIE-Bot-4":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
|
||||
case "BLOOMZ-7B":
|
||||
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
|
||||
case "Embedding-V1":
|
||||
@@ -156,8 +163,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
var err error
|
||||
if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
|
||||
return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
if apiKey, err = baidu.GetAccessToken(apiKey); err != nil {
|
||||
return openai.ErrorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
|
||||
}
|
||||
fullRequestURL += "?access_token=" + apiKey
|
||||
case APITypePaLM:
|
||||
@@ -165,9 +172,20 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if baseURL != "" {
|
||||
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
fullRequestURL += "?key=" + apiKey
|
||||
case APITypeGemini:
|
||||
requestBaseURL := "https://generativelanguage.googleapis.com"
|
||||
if baseURL != "" {
|
||||
requestBaseURL = baseURL
|
||||
}
|
||||
version := "v1"
|
||||
if c.GetString("api_version") != "" {
|
||||
version = c.GetString("api_version")
|
||||
}
|
||||
action := "generateContent"
|
||||
if textRequest.Stream {
|
||||
action = "streamGenerateContent"
|
||||
}
|
||||
fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
|
||||
case APITypeZhipu:
|
||||
method := "invoke"
|
||||
if textRequest.Stream {
|
||||
@@ -176,21 +194,23 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
|
||||
case APITypeAli:
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
|
||||
if relayMode == RelayModeEmbeddings {
|
||||
if relayMode == constant.RelayModeEmbeddings {
|
||||
fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||
}
|
||||
case APITypeTencent:
|
||||
fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions"
|
||||
case APITypeAIProxyLibrary:
|
||||
fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
|
||||
}
|
||||
var promptTokens int
|
||||
var completionTokens int
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
case RelayModeCompletions:
|
||||
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case RelayModeModerations:
|
||||
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
|
||||
case constant.RelayModeChatCompletions:
|
||||
promptTokens = openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
case constant.RelayModeCompletions:
|
||||
promptTokens = openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case constant.RelayModeModerations:
|
||||
promptTokens = openai.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||
}
|
||||
preConsumedTokens := common.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
@@ -202,11 +222,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
|
||||
}
|
||||
err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if userQuota > 100*preConsumedQuota {
|
||||
// in this case, we do not pre-consume quota
|
||||
@@ -214,17 +237,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
preConsumedQuota = 0
|
||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
|
||||
}
|
||||
if consumeQuota && preConsumedQuota > 0 {
|
||||
if preConsumedQuota > 0 {
|
||||
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
@@ -232,62 +255,86 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
switch apiType {
|
||||
case APITypeClaude:
|
||||
claudeRequest := requestOpenAI2Claude(textRequest)
|
||||
claudeRequest := anthropic.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeBaidu:
|
||||
var jsonData []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest)
|
||||
case constant.RelayModeEmbeddings:
|
||||
baiduEmbeddingRequest := baidu.ConvertEmbeddingRequest(textRequest)
|
||||
jsonData, err = json.Marshal(baiduEmbeddingRequest)
|
||||
default:
|
||||
baiduRequest := requestOpenAI2Baidu(textRequest)
|
||||
baiduRequest := baidu.ConvertRequest(textRequest)
|
||||
jsonData, err = json.Marshal(baiduRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
case APITypePaLM:
|
||||
palmRequest := requestOpenAI2PaLM(textRequest)
|
||||
palmRequest := google.ConvertPaLMRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeGemini:
|
||||
geminiChatRequest := google.ConvertGeminiRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeZhipu:
|
||||
zhipuRequest := requestOpenAI2Zhipu(textRequest)
|
||||
zhipuRequest := zhipu.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(zhipuRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeAli:
|
||||
var jsonStr []byte
|
||||
var err error
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
|
||||
case constant.RelayModeEmbeddings:
|
||||
aliEmbeddingRequest := ali.ConvertEmbeddingRequest(textRequest)
|
||||
jsonStr, err = json.Marshal(aliEmbeddingRequest)
|
||||
default:
|
||||
aliRequest := requestOpenAI2Ali(textRequest)
|
||||
aliRequest := ali.ConvertRequest(textRequest)
|
||||
jsonStr, err = json.Marshal(aliRequest)
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeTencent:
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
appId, secretId, secretKey, err := tencent.ParseConfig(apiKey)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError)
|
||||
}
|
||||
tencentRequest := tencent.ConvertRequest(textRequest)
|
||||
tencentRequest.AppId = appId
|
||||
tencentRequest.SecretId = secretId
|
||||
jsonStr, err := json.Marshal(tencentRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
sign := tencent.GetSign(*tencentRequest, secretKey)
|
||||
c.Request.Header.Set("Authorization", sign)
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case APITypeAIProxyLibrary:
|
||||
aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
|
||||
aiProxyLibraryRequest := aiproxy.ConvertRequest(textRequest)
|
||||
aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
|
||||
jsonStr, err := json.Marshal(aiProxyLibraryRequest)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
}
|
||||
@@ -299,7 +346,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if apiType != APITypeXunfei { // cause xunfei use websocket
|
||||
req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
apiKey := c.Request.Header.Get("Authorization")
|
||||
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
|
||||
@@ -322,30 +369,42 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
req.Header.Set("anthropic-version", anthropicVersion)
|
||||
case APITypeZhipu:
|
||||
token := getZhipuToken(apiKey)
|
||||
token := zhipu.GetToken(apiKey)
|
||||
req.Header.Set("Authorization", token)
|
||||
case APITypeAli:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
if textRequest.Stream {
|
||||
req.Header.Set("X-DashScope-SSE", "enable")
|
||||
}
|
||||
if c.GetString("plugin") != "" {
|
||||
req.Header.Set("X-DashScope-Plugin", c.GetString("plugin"))
|
||||
}
|
||||
case APITypeTencent:
|
||||
req.Header.Set("Authorization", apiKey)
|
||||
case APITypePaLM:
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
case APITypeGemini:
|
||||
req.Header.Set("x-goog-api-key", apiKey)
|
||||
default:
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
}
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
|
||||
resp, err = httpClient.Do(req)
|
||||
resp, err = util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
@@ -359,63 +418,60 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
}(c.Request.Context())
|
||||
}
|
||||
return relayErrorHandler(resp)
|
||||
return util.RelayErrorHandler(resp)
|
||||
}
|
||||
}
|
||||
|
||||
var textResponse TextResponse
|
||||
var textResponse openai.SlimTextResponse
|
||||
tokenName := c.GetString("token_name")
|
||||
|
||||
defer func(ctx context.Context) {
|
||||
// c.Writer.Flush()
|
||||
go func() {
|
||||
if consumeQuota {
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
|
||||
quota = promptTokens + int(float64(completionTokens)*completionRatio)
|
||||
quota = int(float64(quota) * ratio)
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
quota := 0
|
||||
completionRatio := common.GetCompletionRatio(textRequest.Model)
|
||||
promptTokens = textResponse.Usage.PromptTokens
|
||||
completionTokens = textResponse.Usage.CompletionTokens
|
||||
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
|
||||
if ratio != 0 && quota <= 0 {
|
||||
quota = 1
|
||||
}
|
||||
totalTokens := promptTokens + completionTokens
|
||||
if totalTokens == 0 {
|
||||
// in this case, must be some error happened
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
|
||||
}()
|
||||
}(c.Request.Context())
|
||||
switch apiType {
|
||||
case APITypeOpenAI:
|
||||
if isStream {
|
||||
err, responseText := openaiStreamHandler(c, resp, relayMode)
|
||||
err, responseText := openai.StreamHandler(c, resp, relayMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := openaiHandler(c, resp, consumeQuota, promptTokens, textRequest.Model)
|
||||
err, usage := openai.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -426,15 +482,15 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeClaude:
|
||||
if isStream {
|
||||
err, responseText := claudeStreamHandler(c, resp)
|
||||
err, responseText := anthropic.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage := anthropic.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -445,7 +501,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeBaidu:
|
||||
if isStream {
|
||||
err, usage := baiduStreamHandler(c, resp)
|
||||
err, usage := baidu.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -454,13 +510,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
var err *openai.ErrorWithStatusCode
|
||||
var usage *openai.Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = baiduEmbeddingHandler(c, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = baidu.EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = baiduHandler(c, resp)
|
||||
err, usage = baidu.Handler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -472,15 +528,34 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypePaLM:
|
||||
if textRequest.Stream { // PaLM2 API does not support stream
|
||||
err, responseText := palmStreamHandler(c, resp)
|
||||
err, responseText := google.PaLMStreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
|
||||
textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage := google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeGemini:
|
||||
if textRequest.Stream {
|
||||
err, responseText := google.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -491,7 +566,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeZhipu:
|
||||
if isStream {
|
||||
err, usage := zhipuStreamHandler(c, resp)
|
||||
err, usage := zhipu.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -502,7 +577,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens
|
||||
return nil
|
||||
} else {
|
||||
err, usage := zhipuHandler(c, resp)
|
||||
err, usage := zhipu.Handler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -515,7 +590,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
case APITypeAli:
|
||||
if isStream {
|
||||
err, usage := aliStreamHandler(c, resp)
|
||||
err, usage := ali.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -524,13 +599,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
var err *openai.ErrorWithStatusCode
|
||||
var usage *openai.Usage
|
||||
switch relayMode {
|
||||
case RelayModeEmbeddings:
|
||||
err, usage = aliEmbeddingHandler(c, resp)
|
||||
case constant.RelayModeEmbeddings:
|
||||
err, usage = ali.EmbeddingHandler(c, resp)
|
||||
default:
|
||||
err, usage = aliHandler(c, resp)
|
||||
err, usage = ali.Handler(c, resp)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -545,14 +620,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
splits := strings.Split(auth, "|")
|
||||
if len(splits) != 3 {
|
||||
return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
return openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
|
||||
}
|
||||
var err *OpenAIErrorWithStatusCode
|
||||
var usage *Usage
|
||||
var err *openai.ErrorWithStatusCode
|
||||
var usage *openai.Usage
|
||||
if isStream {
|
||||
err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
err, usage = xunfei.StreamHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
} else {
|
||||
err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
err, usage = xunfei.Handler(c, textRequest, splits[0], splits[1], splits[2])
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -563,7 +638,7 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return nil
|
||||
case APITypeAIProxyLibrary:
|
||||
if isStream {
|
||||
err, usage := aiProxyLibraryStreamHandler(c, resp)
|
||||
err, usage := aiproxy.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -572,7 +647,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
err, usage := aiProxyLibraryHandler(c, resp)
|
||||
err, usage := aiproxy.Handler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if usage != nil {
|
||||
textResponse.Usage = *usage
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case APITypeTencent:
|
||||
if isStream {
|
||||
err, responseText := tencent.StreamHandler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
textResponse.Usage.PromptTokens = promptTokens
|
||||
textResponse.Usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
return nil
|
||||
} else {
|
||||
err, usage := tencent.Handler(c, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -582,6 +676,6 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||
return openai.ErrorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
166
relay/util/common.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/relay/channel/openai"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ShouldDisableChannel(err *openai.Error, statusCode int) bool {
|
||||
if !common.AutomaticDisableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if statusCode == http.StatusUnauthorized {
|
||||
return true
|
||||
}
|
||||
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ShouldEnableChannel(err error, openAIErr *openai.Error) bool {
|
||||
if !common.AutomaticEnableChannelEnabled {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if openAIErr != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type GeneralErrorResponse struct {
|
||||
Error openai.Error `json:"error"`
|
||||
Message string `json:"message"`
|
||||
Msg string `json:"msg"`
|
||||
Err string `json:"err"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
Header struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"header"`
|
||||
Response struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
} `json:"response"`
|
||||
}
|
||||
|
||||
func (e GeneralErrorResponse) ToMessage() string {
|
||||
if e.Error.Message != "" {
|
||||
return e.Error.Message
|
||||
}
|
||||
if e.Message != "" {
|
||||
return e.Message
|
||||
}
|
||||
if e.Msg != "" {
|
||||
return e.Msg
|
||||
}
|
||||
if e.Err != "" {
|
||||
return e.Err
|
||||
}
|
||||
if e.ErrorMsg != "" {
|
||||
return e.ErrorMsg
|
||||
}
|
||||
if e.Header.Message != "" {
|
||||
return e.Header.Message
|
||||
}
|
||||
if e.Response.Error.Message != "" {
|
||||
return e.Response.Error.Message
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *openai.ErrorWithStatusCode) {
|
||||
ErrorWithStatusCode = &openai.ErrorWithStatusCode{
|
||||
StatusCode: resp.StatusCode,
|
||||
Error: openai.Error{
|
||||
Message: "",
|
||||
Type: "upstream_error",
|
||||
Code: "bad_response_status_code",
|
||||
Param: strconv.Itoa(resp.StatusCode),
|
||||
},
|
||||
}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var errResponse GeneralErrorResponse
|
||||
err = json.Unmarshal(responseBody, &errResponse)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if errResponse.Error.Message != "" {
|
||||
// OpenAI format error, so we override the default one
|
||||
ErrorWithStatusCode.Error = errResponse.Error
|
||||
} else {
|
||||
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
|
||||
}
|
||||
if ErrorWithStatusCode.Error.Message == "" {
|
||||
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
switch channelType {
|
||||
case common.ChannelTypeOpenAI:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
|
||||
case common.ChannelTypeAzure:
|
||||
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
|
||||
}
|
||||
}
|
||||
return fullRequestURL
|
||||
}
|
||||
|
||||
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
|
||||
// quotaDelta is remaining quota to be consumed
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
// totalQuota is total quota consumed
|
||||
if totalQuota != 0 {
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
|
||||
model.UpdateChannelUsedQuota(channelId, totalQuota)
|
||||
}
|
||||
if totalQuota <= 0 {
|
||||
common.LogError(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
|
||||
}
|
||||
}
|
||||
|
||||
func GetAPIVersion(c *gin.Context) string {
|
||||
query := c.Request.URL.Query()
|
||||
apiVersion := query.Get("api-version")
|
||||
if apiVersion == "" {
|
||||
apiVersion = c.GetString("api_version")
|
||||
}
|
||||
return apiVersion
|
||||
}
|
24
relay/util/init.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
var HTTPClient *http.Client
|
||||
var ImpatientHTTPClient *http.Client
|
||||
|
||||
func init() {
|
||||
if common.RelayTimeout == 0 {
|
||||
HTTPClient = &http.Client{}
|
||||
} else {
|
||||
HTTPClient = &http.Client{
|
||||
Timeout: time.Duration(common.RelayTimeout) * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
ImpatientHTTPClient = &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
}
|
@@ -35,6 +35,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
selfRoute := userRoute.Group("/")
|
||||
selfRoute.Use(middleware.UserAuth())
|
||||
{
|
||||
selfRoute.GET("/dashboard", controller.GetUserDashboard)
|
||||
selfRoute.GET("/self", controller.GetSelf)
|
||||
selfRoute.PUT("/self", controller.UpdateSelf)
|
||||
selfRoute.DELETE("/self", controller.DeleteSelf)
|
||||
@@ -74,6 +75,7 @@ func SetApiRouter(router *gin.Engine) {
|
||||
channelRoute.GET("/update_balance/:id", controller.UpdateChannelBalance)
|
||||
channelRoute.POST("/", controller.AddChannel)
|
||||
channelRoute.PUT("/", controller.UpdateChannel)
|
||||
channelRoute.DELETE("/disabled", controller.DeleteDisabledChannel)
|
||||
channelRoute.DELETE("/:id", controller.DeleteChannel)
|
||||
}
|
||||
tokenRoute := apiRouter.Group("/token")
|
||||
|
@@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
func SetRouter(router *gin.Engine, buildFS embed.FS) {
|
||||
SetApiRouter(router)
|
||||
SetDashboardRouter(router)
|
||||
SetRelayRouter(router)
|
||||
@@ -20,7 +20,7 @@ func SetRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
common.SysLog("FRONTEND_BASE_URL is ignored on master node")
|
||||
}
|
||||
if frontendBaseUrl == "" {
|
||||
SetWebRouter(router, buildFS, indexPage)
|
||||
SetWebRouter(router, buildFS)
|
||||
} else {
|
||||
frontendBaseUrl = strings.TrimSuffix(frontendBaseUrl, "/")
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
|
@@ -17,7 +17,7 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
modelsRouter.GET("/:model", controller.RetrieveModel)
|
||||
}
|
||||
relayV1Router := router.Group("/v1")
|
||||
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
|
||||
relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute())
|
||||
{
|
||||
relayV1Router.POST("/completions", controller.Relay)
|
||||
relayV1Router.POST("/chat/completions", controller.Relay)
|
||||
@@ -29,17 +29,44 @@ func SetRelayRouter(router *gin.Engine) {
|
||||
relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
|
||||
relayV1Router.POST("/audio/transcriptions", controller.Relay)
|
||||
relayV1Router.POST("/audio/translations", controller.Relay)
|
||||
relayV1Router.POST("/audio/speech", controller.Relay)
|
||||
relayV1Router.GET("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/files", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/moderations", controller.Relay)
|
||||
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented)
|
||||
relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented)
|
||||
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)
|
||||
}
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package router
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
"github.com/gin-contrib/gzip"
|
||||
"github.com/gin-contrib/static"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -12,17 +13,18 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func SetWebRouter(router *gin.Engine, buildFS embed.FS, indexPage []byte) {
|
||||
func SetWebRouter(router *gin.Engine, buildFS embed.FS) {
|
||||
indexPageData, _ := buildFS.ReadFile(fmt.Sprintf("web/build/%s/index.html", common.Theme))
|
||||
router.Use(gzip.Gzip(gzip.DefaultCompression))
|
||||
router.Use(middleware.GlobalWebRateLimit())
|
||||
router.Use(middleware.Cache())
|
||||
router.Use(static.Serve("/", common.EmbedFolder(buildFS, "web/build")))
|
||||
router.Use(static.Serve("/", common.EmbedFolder(buildFS, fmt.Sprintf("web/build/%s", common.Theme))))
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
if strings.HasPrefix(c.Request.RequestURI, "/v1") || strings.HasPrefix(c.Request.RequestURI, "/api") {
|
||||
controller.RelayNotFound(c)
|
||||
return
|
||||
}
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPage)
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", indexPageData)
|
||||
})
|
||||
}
|
||||
|
@@ -1,21 +1,38 @@
|
||||
# React Template
|
||||
# One API 的前端界面
|
||||
|
||||
## Basic Usages
|
||||
> 每个文件夹代表一个主题,欢迎提交你的主题
|
||||
|
||||
```shell
|
||||
# Runs the app in the development mode
|
||||
npm start
|
||||
## 提交新的主题
|
||||
|
||||
# Builds the app for production to the `build` folder
|
||||
npm run build
|
||||
```
|
||||
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
|
||||
|
||||
If you want to change the default server, please set `REACT_APP_SERVER` environment variables before build,
|
||||
for example: `REACT_APP_SERVER=http://your.domain.com`.
|
||||
1. 在 `web` 文件夹下新建一个文件夹,文件夹名为主题名。
|
||||
2. 把你的主题文件放到这个文件夹下。
|
||||
3. 修改你的 `package.json` 文件,把 `build` 命令改为:`"build": "react-scripts build && mv -f build ../build/default"`,其中 `default` 为你的主题名。
|
||||
4. 修改 `common/constants.go` 中的 `ValidThemes`,把你的主题名称注册进去。
|
||||
5. 修改 `web/THEMES` 文件,这里也需要同步修改。
|
||||
|
||||
Before you start editing, make sure your `Actions on Save` options have `Optimize imports` & `Run Prettier` enabled.
|
||||
## 主题列表
|
||||
|
||||
## Reference
|
||||
### 主题:default
|
||||
|
||||
1. https://github.com/OIerDb-ng/OIerDb
|
||||
2. https://github.com/cornflourblue/react-hooks-redux-registration-login-example
|
||||
默认主题,由 [JustSong](https://github.com/songquanpeng) 开发。
|
||||
|
||||
预览:
|
||||
|||
|
||||
|:---:|:---:|
|
||||
|
||||
### 主题:berry
|
||||
|
||||
由 [MartialBE](https://github.com/MartialBE) 开发。
|
||||
|
||||
预览:
|
||||
|||
|
||||
|:---:|:---:|
|
||||
|||
|
||||
|||
|
||||
|||
|
||||
|
||||
#### 开发说明
|
||||
|
||||
请查看 [web/berry/README.md](https://github.com/songquanpeng/one-api/tree/main/web/berry/README.md)
|
||||
|
2
web/THEMES
Normal file
@@ -0,0 +1,2 @@
|
||||
default
|
||||
berry
|
0
web/.gitignore → web/berry/.gitignore
vendored
61
web/berry/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# One API 前端界面
|
||||
|
||||
这个项目是 One API 的前端界面,它基于 [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template) 进行开发。
|
||||
|
||||
## 使用的开源项目
|
||||
|
||||
使用了以下开源项目作为我们项目的一部分:
|
||||
|
||||
- [Berry Free React Admin Template](https://github.com/codedthemes/berry-free-react-admin-template)
|
||||
- [minimal-ui-kit](minimal-ui-kit)
|
||||
|
||||
## 开发说明
|
||||
|
||||
当添加新的渠道时,需要修改以下地方:
|
||||
|
||||
1. `web/berry/src/constants/ChannelConstants.js`
|
||||
|
||||
在该文件中的 `CHANNEL_OPTIONS` 添加新的渠道
|
||||
|
||||
```js
|
||||
export const CHANNEL_OPTIONS = {
|
||||
//key 为渠道ID
|
||||
1: {
|
||||
key: 1, // 渠道ID
|
||||
text: "OpenAI", // 渠道名称
|
||||
value: 1, // 渠道ID
|
||||
color: "primary", // 渠道列表显示的颜色
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
2. `web/berry/src/views/Channel/type/Config.js`
|
||||
|
||||
在该文件中的`typeConfig`添加新的渠道配置, 如果无需配置,可以不添加
|
||||
|
||||
```js
|
||||
const typeConfig = {
|
||||
// key 为渠道ID
|
||||
3: {
|
||||
inputLabel: {
|
||||
// 输入框名称 配置
|
||||
// 对应的字段名称
|
||||
base_url: "AZURE_OPENAI_ENDPOINT",
|
||||
other: "默认 API 版本",
|
||||
},
|
||||
prompt: {
|
||||
// 输入框提示 配置
|
||||
// 对应的字段名称
|
||||
base_url: "请填写AZURE_OPENAI_ENDPOINT",
|
||||
|
||||
// 注意:通过判断 `other` 是否有值来判断是否需要显示 `other` 输入框, 默认是没有值的
|
||||
other: "请输入默认API版本,例如:2023-06-01-preview",
|
||||
},
|
||||
modelGroup: "openai", // 模型组名称,这个值是给 填入渠道支持模型 按钮使用的。 填入渠道支持模型 按钮会根据这个值来获取模型组,如果填写默认是 openai
|
||||
},
|
||||
};
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
本项目中使用的代码遵循 MIT 许可证。
|
9
web/berry/jsconfig.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "esnext",
|
||||
"module": "commonjs",
|
||||
"baseUrl": "src"
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules"]
|
||||
}
|
84
web/berry/package.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"name": "one_api_web",
|
||||
"version": "1.0.0",
|
||||
"proxy": "http://127.0.0.1:3000",
|
||||
"private": true,
|
||||
"homepage": "",
|
||||
"dependencies": {
|
||||
"@emotion/cache": "^11.9.3",
|
||||
"@emotion/react": "^11.9.3",
|
||||
"@emotion/styled": "^11.9.3",
|
||||
"@mui/icons-material": "^5.8.4",
|
||||
"@mui/lab": "^5.0.0-alpha.88",
|
||||
"@mui/material": "^5.8.6",
|
||||
"@mui/system": "^5.8.6",
|
||||
"@mui/utils": "^5.8.6",
|
||||
"@mui/x-date-pickers": "^6.18.5",
|
||||
"@tabler/icons-react": "^2.44.0",
|
||||
"apexcharts": "^3.35.3",
|
||||
"axios": "^0.27.2",
|
||||
"dayjs": "^1.11.10",
|
||||
"formik": "^2.2.9",
|
||||
"framer-motion": "^6.3.16",
|
||||
"history": "^5.3.0",
|
||||
"marked": "^4.1.1",
|
||||
"material-ui-popup-state": "^4.0.1",
|
||||
"notistack": "^3.0.1",
|
||||
"prop-types": "^15.8.1",
|
||||
"react": "^18.2.0",
|
||||
"react-apexcharts": "^1.4.0",
|
||||
"react-device-detect": "^2.2.2",
|
||||
"react-dom": "^18.2.0",
|
||||
"react-perfect-scrollbar": "^1.5.8",
|
||||
"react-redux": "^8.0.2",
|
||||
"react-router": "6.3.0",
|
||||
"react-router-dom": "6.3.0",
|
||||
"react-scripts": "^5.0.1",
|
||||
"react-turnstile": "^1.1.2",
|
||||
"redux": "^4.2.0",
|
||||
"yup": "^0.32.11"
|
||||
},
|
||||
"scripts": {
|
||||
"start": "react-scripts start",
|
||||
"build": "react-scripts build && mv -f build ../build/berry",
|
||||
"test": "react-scripts test",
|
||||
"eject": "react-scripts eject"
|
||||
},
|
||||
"eslintConfig": {
|
||||
"extends": [
|
||||
"react-app"
|
||||
]
|
||||
},
|
||||
"babel": {
|
||||
"presets": [
|
||||
"@babel/preset-react"
|
||||
]
|
||||
},
|
||||
"browserslist": {
|
||||
"production": [
|
||||
"defaults",
|
||||
"not IE 11"
|
||||
],
|
||||
"development": [
|
||||
"last 1 chrome version",
|
||||
"last 1 firefox version",
|
||||
"last 1 safari version"
|
||||
]
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/core": "^7.21.4",
|
||||
"@babel/eslint-parser": "^7.21.3",
|
||||
"eslint": "^8.38.0",
|
||||
"eslint-config-prettier": "^8.8.0",
|
||||
"eslint-config-react-app": "^7.0.1",
|
||||
"eslint-plugin-flowtype": "^8.0.3",
|
||||
"eslint-plugin-import": "^2.27.5",
|
||||
"eslint-plugin-jsx-a11y": "^6.7.1",
|
||||
"eslint-plugin-prettier": "^4.2.1",
|
||||
"eslint-plugin-react": "^7.32.2",
|
||||
"eslint-plugin-react-hooks": "^4.6.0",
|
||||
"immutable": "^4.3.0",
|
||||
"prettier": "^2.8.7",
|
||||
"sass": "^1.53.0"
|
||||
}
|
||||
}
|
BIN
web/berry/public/favicon.ico
Normal file
After Width: | Height: | Size: 40 KiB |
26
web/berry/public/index.html
Normal file
@@ -0,0 +1,26 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<title>One API</title>
|
||||
<link rel="icon" href="/favicon.ico" />
|
||||
<!-- Meta Tags-->
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<meta name="theme-color" content="#2296f3" />
|
||||
<meta
|
||||
name="description"
|
||||
content="OpenAI 接口聚合管理,支持多种渠道包括 Azure,可用于二次分发管理 key,仅单可执行文件,已打包好 Docker 镜像,一键部署,开箱即用"
|
||||
/>
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" />
|
||||
<link
|
||||
href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Poppins:wght@400;500;600;700&family=Roboto:wght@400;500;700&display=swap"
|
||||
rel="stylesheet"
|
||||
/>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<noscript>You need to enable JavaScript to run this app.</noscript>
|
||||
<div id="root"></div>
|
||||
|
||||
</body>
|
||||
</html>
|
43
web/berry/src/App.js
Normal file
@@ -0,0 +1,43 @@
|
||||
import { useSelector } from 'react-redux';
|
||||
|
||||
import { ThemeProvider } from '@mui/material/styles';
|
||||
import { CssBaseline, StyledEngineProvider } from '@mui/material';
|
||||
|
||||
// routing
|
||||
import Routes from 'routes';
|
||||
|
||||
// defaultTheme
|
||||
import themes from 'themes';
|
||||
|
||||
// project imports
|
||||
import NavigationScroll from 'layout/NavigationScroll';
|
||||
|
||||
// auth
|
||||
import UserProvider from 'contexts/UserContext';
|
||||
import StatusProvider from 'contexts/StatusContext';
|
||||
import { SnackbarProvider } from 'notistack';
|
||||
|
||||
// ==============================|| APP ||============================== //
|
||||
|
||||
const App = () => {
|
||||
const customization = useSelector((state) => state.customization);
|
||||
|
||||
return (
|
||||
<StyledEngineProvider injectFirst>
|
||||
<ThemeProvider theme={themes(customization)}>
|
||||
<CssBaseline />
|
||||
<NavigationScroll>
|
||||
<SnackbarProvider autoHideDuration={5000} maxSnack={3} anchorOrigin={{ vertical: 'top', horizontal: 'right' }}>
|
||||
<UserProvider>
|
||||
<StatusProvider>
|
||||
<Routes />
|
||||
</StatusProvider>
|
||||
</UserProvider>
|
||||
</SnackbarProvider>
|
||||
</NavigationScroll>
|
||||
</ThemeProvider>
|
||||
</StyledEngineProvider>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
40
web/berry/src/assets/images/404.svg
Normal file
@@ -0,0 +1,40 @@
|
||||
<svg width="480" height="360" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd"
|
||||
d="M0 198.781c0 41.457 14.945 79.235 39.539 107.785 28.214 32.765 69.128 53.365 114.734 53.434a148.458 148.458 0 0056.495-11.036c9.051-3.699 19.182-3.274 27.948 1.107a75.774 75.774 0 0033.957 8.011c5.023 0 9.942-.495 14.7-1.434 13.581-2.67 25.94-8.99 36.089-17.94 6.379-5.627 14.548-8.456 22.898-8.446h.142c27.589 0 53.215-8.732 74.492-23.696 19.021-13.36 34.554-31.696 44.904-53.225C474.92 234.581 480 213.388 480 190.958c0-76.931-59.774-139.305-133.498-139.305-7.516 0-14.88.663-22.063 1.899C305.418 21.42 271.355 0 232.498 0a103.647 103.647 0 00-45.879 10.661c-13.24 6.487-25.011 15.705-34.641 26.939-32.697.544-62.93 11.69-87.675 30.291C25.351 97.155 0 144.882 0 198.781z"
|
||||
fill="url(#prefix__paint0_linear)" opacity=".2" />
|
||||
<g filter="url(#prefix__filter0_d)">
|
||||
<circle opacity=".15" cx="182.109" cy="97.623" r="44.623" fill="#FFC107" />
|
||||
<circle cx="182.109" cy="97.623" r="23.406" fill="url(#prefix__paint1_linear)" />
|
||||
<path fill-rule="evenodd" clip-rule="evenodd"
|
||||
d="M244.878 306.611c34.56 0 62.575-28.016 62.575-62.575 0-34.56-28.015-62.576-62.575-62.576-34.559 0-62.575 28.016-62.575 62.576 0 34.559 28.016 62.575 62.575 62.575zm0-23.186c21.754 0 39.389-17.635 39.389-39.389 0-21.755-17.635-39.39-39.389-39.39s-39.389 17.635-39.389 39.39c0 21.754 17.635 39.389 39.389 39.389z"
|
||||
fill="#061B64" />
|
||||
<path fill-rule="evenodd" clip-rule="evenodd"
|
||||
d="M174.965 264.592c0-4.133-1.492-5.625-5.637-5.625h-11.373v-66.611c0-4.476-1.492-5.637-5.638-5.637h-9.172a9.866 9.866 0 00-7.948 3.974l-55.03 68.274a11.006 11.006 0 00-1.957 6.787v5.968c0 4.145 1.492 5.637 5.625 5.637h54.676v21.707c0 4.133 1.492 5.625 5.625 5.625h8.12c4.146 0 5.638-1.492 5.638-5.625v-21.707h11.434c4.414 0 5.637-1.492 5.637-5.637v-7.13zm-72.42-5.625l35.966-44.415v44.415h-35.966zM411.607 264.592c0-4.133-1.492-5.625-5.638-5.625h-11.422v-66.611c0-4.476-1.492-5.637-5.637-5.637h-9.111a9.87 9.87 0 00-7.949 3.974l-55.03 68.274a11.011 11.011 0 00-1.981 6.787v5.968c0 4.145 1.492 5.637 5.626 5.637h54.687v21.707c0 4.133 1.492 5.625 5.626 5.625h8.12c4.145 0 5.637-1.492 5.637-5.625v-21.707h11.434c4.476 0 5.638-1.492 5.638-5.637v-7.13zm-72.42-5.625l35.965-44.415v44.415h-35.965z"
|
||||
fill="#2065D1" />
|
||||
<path opacity=".24"
|
||||
d="M425.621 117.222a8.267 8.267 0 00-9.599-8.157 11.129 11.129 0 00-9.784-5.87h-.403a13.23 13.23 0 00-20.365-14.078 13.23 13.23 0 00-5.316 14.078h-.403a11.153 11.153 0 100 22.293h38.68v-.073a8.279 8.279 0 007.19-8.193zM104.258 199.045a7.093 7.093 0 00-7.093-7.092c-.381.007-.761.039-1.138.097a9.552 9.552 0 00-8.425-5.026h-.343a11.348 11.348 0 10-22.012 0h-.342a9.564 9.564 0 100 19.114h33.177v-.061a7.107 7.107 0 006.176-7.032z"
|
||||
fill="#2065D1" />
|
||||
</g>
|
||||
<defs>
|
||||
<linearGradient id="prefix__paint0_linear" x1="328.81" y1="424.032" x2="505.393" y2="26.048"
|
||||
gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#2065D1" />
|
||||
<stop offset="1" stop-color="#2065D1" stop-opacity=".01" />
|
||||
</linearGradient>
|
||||
<linearGradient id="prefix__paint1_linear" x1="135.297" y1="97.623" x2="182.109" y2="144.436"
|
||||
gradientUnits="userSpaceOnUse">
|
||||
<stop stop-color="#FFE16A" />
|
||||
<stop offset="1" stop-color="#B78103" />
|
||||
</linearGradient>
|
||||
<filter id="prefix__filter0_d" x="51" y="49" width="394.621" height="277.611" filterUnits="userSpaceOnUse"
|
||||
color-interpolation-filters="sRGB">
|
||||
<feFlood flood-opacity="0" result="BackgroundImageFix" />
|
||||
<feColorMatrix in="SourceAlpha" values="0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 127 0" />
|
||||
<feOffset dx="8" dy="8" />
|
||||
<feGaussianBlur stdDeviation="6" />
|
||||
<feColorMatrix values="0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.12 0" />
|
||||
<feBlend in2="BackgroundImageFix" result="effect1_dropShadow" />
|
||||
<feBlend in="SourceGraphic" in2="effect1_dropShadow" result="shape" />
|
||||
</filter>
|
||||
</defs>
|
||||
</svg>
|
After Width: | Height: | Size: 3.9 KiB |
65
web/berry/src/assets/images/auth/auth-blue-card.svg
Normal file
After Width: | Height: | Size: 16 KiB |
39
web/berry/src/assets/images/auth/auth-pattern-dark.svg
Normal file
@@ -0,0 +1,39 @@
|
||||
<svg width="670" height="903" viewBox="0 0 670 903" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<mask id="mask0" mask-type="alpha" maskUnits="userSpaceOnUse" x="0" y="0" width="670" height="903">
|
||||
<g opacity="0.2">
|
||||
<path d="M0 0H670V903H0V0Z" fill="white"/>
|
||||
</g>
|
||||
</mask>
|
||||
<g mask="url(#mask0)">
|
||||
<path d="M2030.91 374.849L426.331 1300.78" stroke="#8492C4"/>
|
||||
<path d="M426.409 -527.071L2030.72 399.311" stroke="#8492C4"/>
|
||||
<path d="M1919.22 310.39L314.731 1236.47" stroke="#8492C4"/>
|
||||
<path d="M314.731 -462.612L1919.22 463.467" stroke="#8492C4"/>
|
||||
<path d="M1807.54 245.932L203.055 1172.01" stroke="#8492C4"/>
|
||||
<path d="M203.052 -398.154L1807.54 527.925" stroke="#8492C4"/>
|
||||
<path d="M1695.87 181.473L91.3788 1107.55" stroke="#8492C4"/>
|
||||
<path d="M91.3744 -333.695L1695.86 592.384" stroke="#8492C4"/>
|
||||
<path d="M1584.19 117.014L-20.3012 1043.09" stroke="#8492C4"/>
|
||||
<path d="M-20.3044 -269.237L1584.19 656.843" stroke="#8492C4"/>
|
||||
<path d="M1472.51 52.5562L-131.98 978.636" stroke="#8492C4"/>
|
||||
<path d="M-131.983 -204.778L1472.51 721.301" stroke="#8492C4"/>
|
||||
<path d="M1360.83 -11.9023L-243.658 914.177" stroke="#8492C4"/>
|
||||
<path d="M-243.662 -140.319L1360.83 785.76" stroke="#8492C4"/>
|
||||
<path d="M1249.15 -76.3613L-355.336 849.718" stroke="#8492C4"/>
|
||||
<path d="M-355.341 -75.8608L1249.15 850.219" stroke="#8492C4"/>
|
||||
<path d="M1137.48 -140.819L-467.014 785.26" stroke="#8492C4"/>
|
||||
<path d="M-467.017 -11.4023L1137.47 914.677" stroke="#8492C4"/>
|
||||
<path d="M1025.8 -205.278L-578.692 720.801" stroke="#8492C4"/>
|
||||
<path d="M-578.693 53.0562L1025.8 979.136" stroke="#8492C4"/>
|
||||
<path d="M914.119 -269.736L-690.371 656.343" stroke="#8492C4"/>
|
||||
<path d="M-690.379 117.515L914.111 1043.59" stroke="#8492C4"/>
|
||||
<path d="M802.441 -334.195L-802.052 591.887" stroke="#8492C4"/>
|
||||
<path d="M-802.055 181.974L802.435 1108.05" stroke="#8492C4"/>
|
||||
<path d="M690.762 -398.654L-913.728 527.426" stroke="#8492C4"/>
|
||||
<path d="M-913.731 246.432L690.759 1172.51" stroke="#8492C4"/>
|
||||
<path d="M579.084 -463.112L-1025.41 462.967" stroke="#8492C4"/>
|
||||
<path d="M-1025.41 310.891L579.083 1236.97" stroke="#8492C4"/>
|
||||
<path d="M467.406 -527.571L-1136.91 398.811" stroke="#8492C4"/>
|
||||
<path d="M-1137.09 375.35L467.397 1301.43" stroke="#8492C4"/>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 2.2 KiB |
39
web/berry/src/assets/images/auth/auth-pattern.svg
Normal file
@@ -0,0 +1,39 @@
|
||||
<svg width="670" height="903" viewBox="0 0 670 903" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<mask id="mask0" mask-type="alpha" maskUnits="userSpaceOnUse" x="0" y="0" width="670" height="903">
|
||||
<g opacity="0.2">
|
||||
<path d="M0 0H670V903H0V0Z" fill="white"/>
|
||||
</g>
|
||||
</mask>
|
||||
<g mask="url(#mask0)">
|
||||
<path d="M2030.91 374.849L426.331 1300.78" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M426.409 -527.071L2030.72 399.311" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1919.22 310.39L314.731 1236.47" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M314.731 -462.612L1919.22 463.467" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1807.54 245.932L203.055 1172.01" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M203.052 -398.154L1807.54 527.925" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1695.87 181.473L91.3788 1107.55" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M91.3744 -333.695L1695.86 592.384" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1584.19 117.014L-20.3012 1043.09" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-20.3044 -269.237L1584.19 656.843" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1472.51 52.5562L-131.98 978.636" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-131.983 -204.778L1472.51 721.301" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1360.83 -11.9023L-243.658 914.177" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-243.662 -140.319L1360.83 785.76" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1249.15 -76.3613L-355.336 849.718" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-355.341 -75.8608L1249.15 850.219" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1137.48 -140.819L-467.014 785.26" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-467.017 -11.4023L1137.47 914.677" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M1025.8 -205.278L-578.692 720.801" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-578.693 53.0562L1025.8 979.136" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M914.119 -269.736L-690.371 656.343" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-690.379 117.515L914.111 1043.59" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M802.441 -334.195L-802.052 591.887" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-802.055 181.974L802.435 1108.05" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M690.762 -398.654L-913.728 527.426" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-913.731 246.432L690.759 1172.51" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M579.084 -463.112L-1025.41 462.967" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-1025.41 310.891L579.083 1236.97" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M467.406 -527.571L-1136.91 398.811" stroke="rgba(0,0,0,0.30)"/>
|
||||
<path d="M-1137.09 375.35L467.397 1301.43" stroke="rgba(0,0,0,0.30)"/>
|
||||
</g>
|
||||
</svg>
|
After Width: | Height: | Size: 2.4 KiB |
69
web/berry/src/assets/images/auth/auth-purple-card.svg
Normal file
After Width: | Height: | Size: 22 KiB |
After Width: | Height: | Size: 272 KiB |
40
web/berry/src/assets/images/auth/auth-signup-white-card.svg
Normal file
After Width: | Height: | Size: 15 KiB |
5
web/berry/src/assets/images/icons/earning.svg
Normal file
@@ -0,0 +1,5 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M19 9H9C7.89543 9 7 9.89543 7 11V17C7 18.1046 7.89543 19 9 19H19C20.1046 19 21 18.1046 21 17V11C21 9.89543 20.1046 9 19 9Z" stroke="white" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M14 16C15.1046 16 16 15.1046 16 14C16 12.8954 15.1046 12 14 12C12.8954 12 12 12.8954 12 14C12 15.1046 12.8954 16 14 16Z" fill="#90CAF9"/>
|
||||
<path d="M17 9V7C17 6.46957 16.7893 5.96086 16.4142 5.58579C16.0391 5.21071 15.5304 5 15 5H5C4.46957 5 3.96086 5.21071 3.58579 5.58579C3.21071 5.96086 3 6.46957 3 7V13C3 13.5304 3.21071 14.0391 3.58579 14.4142C3.96086 14.7893 4.46957 15 5 15H7" stroke="white" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
After Width: | Height: | Size: 794 B |
1
web/berry/src/assets/images/icons/github.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg t="1702350903010" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="4215" width="32" height="32"><path d="M512 85.333333C276.266667 85.333333 85.333333 276.266667 85.333333 512a426.410667 426.410667 0 0 0 291.754667 404.821333c21.333333 3.712 29.312-9.088 29.312-20.309333 0-10.112-0.554667-43.690667-0.554667-79.445333-107.178667 19.754667-134.912-26.112-143.445333-50.133334-4.821333-12.288-25.6-50.133333-43.733333-60.288-14.933333-7.978667-36.266667-27.733333-0.554667-28.245333 33.621333-0.554667 57.6 30.933333 65.621333 43.733333 38.4 64.512 99.754667 46.378667 124.245334 35.2 3.754667-27.733333 14.933333-46.378667 27.221333-57.045333-94.933333-10.666667-194.133333-47.488-194.133333-210.688 0-46.421333 16.512-84.778667 43.733333-114.688-4.266667-10.666667-19.2-54.4 4.266667-113.066667 0 0 35.712-11.178667 117.333333 43.776a395.946667 395.946667 0 0 1 106.666667-14.421333c36.266667 0 72.533333 4.778667 106.666666 14.378667 81.578667-55.466667 117.333333-43.690667 117.333334-43.690667 23.466667 58.666667 8.533333 102.4 4.266666 113.066667 27.178667 29.866667 43.733333 67.712 43.733334 114.645333 0 163.754667-99.712 200.021333-194.645334 210.688 15.445333 13.312 28.8 38.912 28.8 78.933333 0 57.045333-0.554667 102.912-0.554666 117.333334 0 11.178667 8.021333 24.490667 29.354666 20.224A427.349333 427.349333 0 0 0 938.666667 512c0-235.733333-190.933333-426.666667-426.666667-426.666667z" fill="#000000" p-id="4216"></path></svg>
|
After Width: | Height: | Size: 1.5 KiB |
1
web/berry/src/assets/images/icons/shape-avatar.svg
Normal file
@@ -0,0 +1 @@
|
||||
<svg height="62" viewBox="0 0 144 62" width="144" xmlns="http://www.w3.org/2000/svg"><path d="m111.34 23.88c-10.62-10.46-18.5-23.88-38.74-23.88h-1.2c-20.24 0-28.12 13.42-38.74 23.88-7.72 9.64-19.44 11.74-32.66 12.12v26h144v-26c-13.22-.38-24.94-2.48-32.66-12.12z" fill="#fff" fill-rule="evenodd"/></svg>
|
After Width: | Height: | Size: 302 B |
6
web/berry/src/assets/images/icons/social-google.svg
Normal file
@@ -0,0 +1,6 @@
|
||||
<svg width="22" height="22" viewBox="0 0 22 22" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M5.06129 13.2253L4.31871 15.9975L1.60458 16.0549C0.793457 14.5504 0.333374 12.8292 0.333374 11C0.333374 9.23119 0.763541 7.56319 1.52604 6.09448H1.52662L3.94296 6.53748L5.00146 8.93932C4.77992 9.58519 4.65917 10.2785 4.65917 11C4.65925 11.783 4.80108 12.5332 5.06129 13.2253Z" fill="#FBBB00"/>
|
||||
<path d="M21.4804 9.00732C21.6029 9.65257 21.6668 10.3189 21.6668 11C21.6668 11.7637 21.5865 12.5086 21.4335 13.2271C20.9143 15.6722 19.5575 17.8073 17.678 19.3182L17.6774 19.3177L14.6339 19.1624L14.2031 16.4734C15.4503 15.742 16.425 14.5974 16.9384 13.2271H11.2346V9.00732H17.0216H21.4804Z" fill="#518EF8"/>
|
||||
<path d="M17.6772 19.3176L17.6777 19.3182C15.8498 20.7875 13.5277 21.6666 11 21.6666C6.93783 21.6666 3.40612 19.3962 1.60449 16.0549L5.0612 13.2253C5.96199 15.6294 8.28112 17.3408 11 17.3408C12.1686 17.3408 13.2634 17.0249 14.2029 16.4734L17.6772 19.3176Z" fill="#28B446"/>
|
||||
<path d="M17.8085 2.78892L14.353 5.61792C13.3807 5.01017 12.2313 4.65908 11 4.65908C8.21963 4.65908 5.85713 6.44896 5.00146 8.93925L1.52658 6.09442H1.526C3.30125 2.67171 6.8775 0.333252 11 0.333252C13.5881 0.333252 15.9612 1.25517 17.8085 2.78892Z" fill="#F14336"/>
|
||||
</svg>
|
After Width: | Height: | Size: 1.2 KiB |