Compare commits

..

35 Commits

Author SHA1 Message Date
ckt1031
2c62053933 feat: better translation 2023-07-28 15:31:31 +08:00
ckt1031
442673a4be feat(ci): add english binary file 2023-07-28 14:59:34 +08:00
ckt1031
bb17e09818 Merge branch 'support-google-oauth' into refactor-main 2023-07-28 14:57:24 +08:00
ckt1031
80288492f9 Merge branch 'support-discord-oauth' into refactor-main 2023-07-28 14:57:11 +08:00
ckt1031
d89e13dac3 fix: use server address as redirect url 2023-07-28 14:56:17 +08:00
ckt1031
0c70181f81 fix: use server address as redirect url 2023-07-28 14:55:40 +08:00
ckt
3893d4bdf1 Merge branch 'songquanpeng:main' into refactor-main 2023-07-28 00:31:21 +08:00
ckt1031
4aa9f97017 fix: bump pgx 2023-07-27 14:55:45 +08:00
ckt1031
ab041d94e2 fix: go deps 2023-07-27 14:55:14 +08:00
ckt1031
18655fed08 Merge branch 'support-postgres-sql' into refactor-main 2023-07-27 14:51:23 +08:00
ckt1031
bba49c959e feat: support postgres 2023-07-27 14:47:45 +08:00
ckt1031
22f4419b85 fix: reject for wrong status code 2023-07-27 11:10:32 +08:00
ckt1031
c7c3b9d326 fix: backstick for postgres 2023-07-27 11:07:08 +08:00
ckt1031
0d6163a9fb feat: support PG-SQL 2023-07-27 00:10:08 +08:00
ckt1031
7419ca511e chore: use own docker image 2023-07-26 11:45:31 +08:00
ckt
ccd7a99b68 Merge branch 'songquanpeng:main' into refactor-main 2023-07-25 22:38:59 +08:00
ckt1031
759423d69e Merge branch 'support-discord-oauth' into refactor-main 2023-07-25 12:22:12 +08:00
ckt1031
04110d4b01 Merge branch 'support-google-oauth' into refactor-main 2023-07-25 12:22:04 +08:00
ckt1031
d6b2131720 fix: discord 2023-07-25 12:21:37 +08:00
ckt1031
438daea433 fix: google id 2023-07-25 12:14:41 +08:00
ckt1031
c6c070b8bd Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 22:05:54 +08:00
ckt1031
13b3bfee2a fix: channel issue 2023-07-24 22:05:21 +08:00
ckt1031
2b42b4f364 feat: support chatgpt next web 2023-07-24 21:49:04 +08:00
ckt1031
f37e41eb1d Merge branch 'support-google-oauth' into refactor-main 2023-07-24 21:31:45 +08:00
ckt1031
c144c64fff feat: support Google OAuth 2023-07-24 20:09:52 +08:00
ckt1031
8956e2fd60 Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 18:12:53 +08:00
ckt1031
30187cebe8 fix: use int instead of bool 2023-07-24 18:12:16 +08:00
ckt1031
00d3a78bef Merge branch 'channel-stream-mode' into refactor-main 2023-07-24 15:35:16 +08:00
ckt1031
a588241515 feat: allow toggling stream mode of channels 2023-07-24 15:30:08 +08:00
ckt1031
546f9e1db5 Merge branch 'support-discord-oauth' into refactor-main 2023-07-24 12:22:14 +08:00
ckt1031
4908a9eddc feat: add Discord OAuth 2023-07-24 12:20:52 +08:00
ckt1031
15cdaee762 feat: bump go deps 2023-07-24 11:38:09 +08:00
ckt1031
395ee121ed feat: bump all node deps 2023-07-24 11:27:46 +08:00
ckt1031
4cea6279ab fix: install @babel/plugin-proposal-private-property-in-object 2023-07-24 11:19:55 +08:00
ckt1031
f50683e75f feat: add turnstile in login form 2023-07-24 10:54:04 +08:00
365 changed files with 4493 additions and 23244 deletions

View File

@@ -38,7 +38,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
justsong/one-api-en ckt1031/one-api-en
- name: Build and push Docker images - name: Build and push Docker images
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3

View File

@@ -42,7 +42,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
justsong/one-api ckt1031/one-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

View File

@@ -49,7 +49,7 @@ jobs:
uses: docker/metadata-action@v4 uses: docker/metadata-action@v4
with: with:
images: | images: |
justsong/one-api ckt1031/one-api
ghcr.io/${{ github.repository }} ghcr.io/${{ github.repository }}
- name: Build and push Docker images - name: Build and push Docker images

59
.github/workflows/linux-release-en.yml vendored Normal file
View File

@@ -0,0 +1,59 @@
name: Linux Release (English)
permissions:
contents: write
on:
push:
tags:
- "*"
- "!*-alpha*"
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Translate
run: |
python ./i18n/translate.py --repository_path . --json_file_path ./i18n/en.json
- uses: actions/setup-node@v3
with:
node-version: 16
- name: Build Frontend
env:
CI: ""
run: |
cd web
npm install
REACT_APP_VERSION=$(git describe --tags) npm run build
cd ..
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: ">=1.18.0"
- name: Build Backend (amd64)
run: |
go mod download
go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-en
- name: Build Backend (arm64)
run: |
sudo apt-get update
sudo apt-get install gcc-aarch64-linux-gnu
CC=aarch64-linux-gnu-gcc CGO_ENABLED=1 GOOS=linux GOARCH=arm64 go build -ldflags "-s -w -X 'one-api/common.Version=$(git describe --tags)' -extldflags '-static'" -o one-api-arm64-en
- name: Release
uses: softprops/action-gh-release@v1
if: startsWith(github.ref, 'refs/tags/')
with:
files: |
one-api-en
one-api-arm64-en
draft: true
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -7,11 +7,6 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -23,13 +18,13 @@ jobs:
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
- name: Build Frontend (theme default) - name: Build Frontend
env: env:
CI: "" CI: ""
run: | run: |
cd web cd web
git describe --tags > VERSION npm install
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh REACT_APP_VERSION=$(git describe --tags) npm run build
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3

View File

@@ -7,11 +7,6 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: macos-latest runs-on: macos-latest
@@ -23,13 +18,13 @@ jobs:
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
- name: Build Frontend (theme default) - name: Build Frontend
env: env:
CI: "" CI: ""
run: | run: |
cd web cd web
git describe --tags > VERSION npm install
REACT_APP_VERSION=$(git describe --tags) chmod u+x ./build.sh && ./build.sh REACT_APP_VERSION=$(git describe --tags) npm run build
cd .. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3

View File

@@ -7,11 +7,6 @@ on:
tags: tags:
- '*' - '*'
- '!*-alpha*' - '!*-alpha*'
workflow_dispatch:
inputs:
name:
description: 'reason'
required: false
jobs: jobs:
release: release:
runs-on: windows-latest runs-on: windows-latest
@@ -26,14 +21,14 @@ jobs:
- uses: actions/setup-node@v3 - uses: actions/setup-node@v3
with: with:
node-version: 16 node-version: 16
- name: Build Frontend (theme default) - name: Build Frontend
env: env:
CI: "" CI: ""
run: | run: |
cd web/default cd web
npm install npm install
REACT_APP_VERSION=$(git describe --tags) npm run build REACT_APP_VERSION=$(git describe --tags) npm run build
cd ../.. cd ..
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v3 uses: actions/setup-go@v3
with: with:

5
.gitignore vendored
View File

@@ -4,7 +4,4 @@ upload
*.exe *.exe
*.db *.db
build build
*.db-journal *.db-journal
logs
data
/web/node_modules

View File

@@ -1,16 +1,10 @@
FROM node:16 as builder FROM node:16 as builder
WORKDIR /web WORKDIR /build
COPY ./VERSION .
COPY ./web . COPY ./web .
COPY ./VERSION .
WORKDIR /web/default
RUN npm install RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build RUN REACT_APP_VERSION=$(cat VERSION) npm run build
WORKDIR /web/berry
RUN npm install
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
FROM golang AS builder2 FROM golang AS builder2
@@ -19,10 +13,9 @@ ENV GO111MODULE=on \
GOOS=linux GOOS=linux
WORKDIR /build WORKDIR /build
ADD go.mod go.sum ./
RUN go mod download
COPY . . COPY . .
COPY --from=builder /web/build ./web/build COPY --from=builder /build/build ./web/build
RUN go mod download
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
FROM alpine FROM alpine
@@ -35,4 +28,4 @@ RUN apk update \
COPY --from=builder2 /build/one-api / COPY --from=builder2 /build/one-api /
EXPOSE 3000 EXPOSE 3000
WORKDIR /data WORKDIR /data
ENTRYPOINT ["/one-api"] ENTRYPOINT ["/one-api"]

View File

@@ -1,9 +1,9 @@
<p align="right"> <p align="right">
<a href="./README.md">中文</a> | <strong>English</strong> | <a href="./README.ja.md">日本語</a> <a href="./README.md">中文</a> | <strong>English</strong>
</p> </p>
<p align="center"> <p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p> </p>
<div align="center"> <div align="center">
@@ -57,13 +57,15 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
> **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability. > **Note**: The latest image pulled from Docker may be an `alpha` release. Specify the version manually if you require stability.
## Features ## Features
1. Support for multiple large models: 1. Supports multiple API access channels:
+ [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)) + [x] Official OpenAI channel (support proxy configuration)
+ [x] [Anthropic Claude Series Models](https://anthropic.com) + [x] **Azure OpenAI API**
+ [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google) + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [OpenAI-SB](https://openai-sb.com)
+ [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html) + [x] [API2D](https://api2d.com/r/197971)
+ [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
+ [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (invitation code: `OneAPI`)
+ [x] Custom channel: Various third-party proxy services not included in the list
2. Supports access to multiple channels through **load balancing**. 2. Supports access to multiple channels through **load balancing**.
3. Supports **stream mode** that enables typewriter-like effect through stream transmission. 3. Supports **stream mode** that enables typewriter-like effect through stream transmission.
4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details. 4. Supports **multi-machine deployment**. [See here](#multi-machine-deployment) for more details.
@@ -134,12 +136,12 @@ The initial account username is `root` and password is `123456`.
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# Build the frontend # Build the frontend
cd one-api/web/default cd one-api/web
npm install npm install
npm run build npm run build
# Build the backend # Build the backend
cd ../.. cd ..
go mod download go mod download
go build -ldflags "-s -w" -o one-api go build -ldflags "-s -w" -o one-api
``` ```
@@ -173,12 +175,7 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
<summary><strong>Deploy on Sealos</strong></summary> <summary><strong>Deploy on Sealos</strong></summary>
<div> <div>
> Sealos supports high concurrency, dynamic scaling, and stable operations for millions of users. Please refer to [this tutorial](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md).
> Click the button below to deploy with one click.👇
[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
</div> </div>
</details> </details>
@@ -189,10 +186,8 @@ If you encounter a blank page after deployment, refer to [#97](https://github.co
> Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage. > Zeabur's servers are located overseas, automatically solving network issues, and the free quota is sufficient for personal usage.
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. First, fork the code. 1. First, fork the code.
2. Go to [Zeabur](https://zeabur.com?referralCode=songquanpeng), log in, and enter the console. 2. Go to [Zeabur](https://zeabur.com/), log in, and enter the console.
3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port). 3. Create a new project. In Service -> Add Service, select Marketplace, and choose MySQL. Note down the connection parameters (username, password, address, and port).
4. Copy the connection parameters and run ```create database `one-api` ``` to create the database. 4. Copy the connection parameters and run ```create database `one-api` ``` to create the database.
5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository. 5. Then, in Service -> Add Service, select Git (authorization is required for the first use) and choose your forked repository.
@@ -285,7 +280,7 @@ If the channel ID is not provided, load balancing will be used to distribute the
+ Double-check that your interface address and API Key are correct. + Double-check that your interface address and API Key are correct.
## Related Projects ## Related Projects
[FastGPT](https://github.com/labring/FastGPT): Knowledge question answering system based on the LLM [FastGPT](https://github.com/c121914yu/FastGPT): Build an AI knowledge base in three minutes
## Note ## Note
This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes. This project is an open-source project. Please use it in compliance with OpenAI's [Terms of Use](https://openai.com/policies/terms-of-use) and **applicable laws and regulations**. It must not be used for illegal purposes.

View File

@@ -1,300 +0,0 @@
<p align="right">
<a href="./README.md">中文</a> | <a href="./README.en.md">English</a> | <strong>日本語</strong>
</p>
<p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p>
<div align="center">
# One API
_✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_
</div>
<p align="center">
<a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE">
<img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://hub.docker.com/repository/docker/justsong/one-api">
<img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull">
</a>
<a href="https://github.com/songquanpeng/one-api/releases/latest">
<img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release">
</a>
<a href="https://goreportcard.com/report/github.com/songquanpeng/one-api">
<img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard">
</a>
</p>
<p align="center">
<a href="#deployment">デプロイチュートリアル</a>
·
<a href="#usage">使用方法</a>
·
<a href="https://github.com/songquanpeng/one-api/issues">フィードバック</a>
·
<a href="#screenshots">スクリーンショット</a>
·
<a href="https://openai.justsong.cn/">ライブデモ</a>
·
<a href="#faq">FAQ</a>
·
<a href="#related-projects">関連プロジェクト</a>
·
<a href="https://iamazing.cn/page/reward">寄付</a>
</p>
> **警告**: この README は ChatGPT によって翻訳されています。翻訳ミスを発見した場合は遠慮なく PR を投稿してください。
> **警告** 英語版の Docker イメージは `justsong/one-api-en` です。
> **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。
## 特徴
1. 複数の大型モデルをサポート:
+ [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
+ [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
+ [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
+ [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。
3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。
4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。
5. トークンの有効期限や使用回数を設定できる**トークン管理**に対応しています。
6. **バウチャー管理**に対応しており、バウチャーの一括生成やエクスポートが可能です。バウチャーは口座残高の補充に利用できます。
7. **チャンネル管理**に対応し、チャンネルの一括作成が可能。
8. グループごとに異なるレートを設定するための**ユーザーグループ**と**チャンネルグループ**をサポートしています。
9. チャンネル**モデルリスト設定**に対応。
10. **クォータ詳細チェック**をサポート。
11. **ユーザー招待報酬**をサポートします。
12. 米ドルでの残高表示が可能。
13. 新規ユーザー向けのお知らせ公開、リチャージリンク設定、初期残高設定に対応。
14. 豊富な**カスタマイズ**オプションを提供します:
1. システム名、ロゴ、フッターのカスタマイズが可能。
2. HTML と Markdown コードを使用したホームページとアバウトページのカスタマイズ、または iframe を介したスタンドアロンウェブページの埋め込みをサポートしています。
15. システム・アクセストークンによる管理 API アクセスをサポートする。
16. Cloudflare Turnstile によるユーザー認証に対応。
17. ユーザー管理と複数のユーザーログイン/登録方法をサポート:
+ 電子メールによるログイン/登録とパスワードリセット。
+ [GitHub OAuth](https://github.com/settings/applications/new)。
+ WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。
18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。
## デプロイメント
### Docker デプロイメント
デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`
コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`
`-p 3000:3000` の最初の `3000` はホストのポートで、必要に応じて変更できます。
データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。
Nginxリファレンス設定:
```
server{
server_name openai.justsong.cn; # ドメイン名は適宜変更
location / {
client_max_body_size 64m;
proxy_http_version 1.1;
proxy_pass http://localhost:3000; # それに応じてポートを変更
proxy_set_header Host $host;
proxy_set_header X-Forwarded-For $remote_addr;
proxy_cache_bypass $http_upgrade;
proxy_set_header Accept-Encoding gzip;
proxy_read_timeout 300s; # GPT-4 はより長いタイムアウトが必要
}
}
```
次に、Let's Encrypt certbot を使って HTTPS を設定します:
```bash
# Ubuntu に certbot をインストール:
sudo snap install --classic certbot
sudo ln -s /snap/bin/certbot /usr/bin/certbot
# 証明書の生成と Nginx 設定の変更
sudo certbot --nginx
# プロンプトに従う
# Nginx を再起動
sudo service nginx restart
```
初期アカウントのユーザー名は `root` で、パスワードは `123456` です。
### マニュアルデプロイ
1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする:
```shell
git clone https://github.com/songquanpeng/one-api.git
# フロントエンドのビルド
cd one-api/web/default
npm install
npm run build
# バックエンドのビルド
cd ../..
go mod download
go build -ldflags "-s -w" -o one-api
```
2. 実行:
```shell
chmod u+x one-api
./one-api --port 3000 --log-dir ./logs
```
3. [http://localhost:3000/](http://localhost:3000/) にアクセスし、ログインする。初期アカウントのユーザー名は `root`、パスワードは `123456` である。
より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。
### マルチマシンデプロイ
1. すべてのサーバに同じ `SESSION_SECRET` を設定する。
2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。
3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。
4. データベースから定期的に設定を同期するサーバーには `SYNC_FREQUENCY` を設定する。
5. マスター以外のノードでは、オプションで `FRONTEND_BASE_URL` を設定して、ページ要求をマスターサーバーにリダイレクトすることができます。
6. マスター以外のノードには Redis を個別にインストールし、`REDIS_CONN_STRING` を設定して、キャッシュの有効期限が切れていないときにデータベースにゼロレイテンシーでアクセスできるようにする。
7. メインサーバーでもデータベースへのアクセスが高レイテンシになる場合は、Redis を有効にし、`SYNC_FREQUENCY` を設定してデータベースから定期的に設定を同期する必要がある。
Please refer to the [environment variables](#environment-variables) section for details on using environment variables.
### コントロールパネル(例: Baotaへの展開
詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。
配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。
### サードパーティプラットフォームへのデプロイ
<details>
<summary><strong>Sealos へのデプロイ</strong></summary>
<div>
> Sealos は、高い同時実行性、ダイナミックなスケーリング、数百万人のユーザーに対する安定した運用をサポートしています。
> 下のボタンをクリックすると、ワンクリックで展開できます。👇
[![](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
</div>
</details>
<details>
<summary><strong>Zeabur へのデプロイ</strong></summary>
<div>
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. まず、コードをフォークする。
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。
5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。
6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。
7. 再デプロイを選択します。
8. Domains タブで、"my-one-api" のような適切なドメイン名の接頭辞を選択する。最終的なドメイン名は "my-one-api.zeabur.app" となります。独自のドメイン名を CNAME することもできます。
9. デプロイが完了するのを待ち、生成されたドメイン名をクリックして One API にアクセスします。
</div>
</details>
## コンフィグ
システムは箱から出してすぐに使えます。
環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。
システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。
## 使用方法
`Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。
アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。
OpenAI API が使用されている場所では、API Base に One API のデプロイアドレスを設定することを忘れないでください(例: `https://openai.justsong.cn`。API Key は One API で生成されたトークンでなければなりません。
具体的な API Base のフォーマットは、使用しているクライアントに依存することに注意してください。
```mermaid
graph LR
A(ユーザ)
A --->|リクエスト| B(One API)
B -->|中継リクエスト| C(OpenAI)
B -->|中継リクエスト| D(Azure)
B -->|中継リクエスト| E(その他のダウンストリームチャンネル)
```
現在のリクエストにどのチャネルを使うかを指定するには、トークンの後に チャネル ID を追加します: 例えば、`Authorization: Bearer ONE_API_KEY-CHANNEL_ID` のようにします。
チャンネル ID を指定するためには、トークンは管理者によって作成される必要があることに注意してください。
もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。
### 環境変数
1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。
+ 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。
+ 例: `SESSION_SECRET=random_string`
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
+ 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
+ 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
+ 例: `SYNC_FREQUENCY=60`
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master``slave` である。設定されていない場合、デフォルトは `master`
+ 例: `NODE_TYPE=slave`
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
+ 例: `CHANNEL_UPDATE_FREQUENCY=1440`
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
+ 例: `CHANNEL_TEST_FREQUENCY=1440`
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
+ 例: `POLLING_INTERVAL=5`
### コマンドラインパラメータ
1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。
+ 例: `--port 3000`
2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。
+ 例: `--log-dir ./logs`
3. `--version`: システムのバージョン番号を表示して終了する。
4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。
## スクリーンショット
![channel](https://user-images.githubusercontent.com/39998050/233837954-ae6683aa-5c4f-429f-a949-6645a83c9490.png)
![token](https://user-images.githubusercontent.com/39998050/233837971-dab488b7-6d96-43af-b640-a168e8d1c9bf.png)
## FAQ
1. ルマとは何かどのように計算されますかOne API にはノルマ計算の問題はありますか?
+ ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率)
+ 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。
+ ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。
2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか?
+ トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。
+ トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。
3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか?
+ ユーザーとチャンネルグループの設定を確認してください。
+ チャンネルモデルの設定も確認してください。
4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value"
+ このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。
+ ほとんどの場合、デプロイサイトのIPかプロキシのードが CloudFlare によってブロックされています。
5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch"
+ デプロイ時に `BASE_URL` を設定しないでください。
+ インターフェイスアドレスと API Key が正しいか再確認してください。
## 関連プロジェクト
[FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム
## 注
本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。
このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。
このプロジェクトを基にした派生プロジェクトについても同様です。
帰属表示を含めたくない場合は、事前に許可を得なければなりません。
MIT ライセンスによると、このプロジェクトを利用するリスクと責任は利用者が負うべきであり、このオープンソースプロジェクトの開発者は責任を負いません。

188
README.md
View File

@@ -1,10 +1,10 @@
<p align="right"> <p align="right">
<strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a> <strong>中文</strong> | <a href="./README.en.md">English</a>
</p> </p>
<p align="center"> <p align="center">
<a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/default/public/logo.png" width="150" height="150" alt="one-api logo"></a> <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
</p> </p>
<div align="center"> <div align="center">
@@ -51,32 +51,27 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
<a href="https://iamazing.cn/page/reward">赞赏支持</a> <a href="https://iamazing.cn/page/reward">赞赏支持</a>
</p> </p>
> [!NOTE] > **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
>
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
> [!WARNING] > **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
> [!WARNING] > **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
> 使用 root 用户初次登录系统后,务必修改默认密码 `123456`
## 功能 ## 功能
1. 支持多种大模型: 1. 支持多种大模型:
+ [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference)
+ [x] [Anthropic Claude 系列模型](https://anthropic.com) + [x] [Anthropic Claude 系列模型](https://anthropic.com)
+ [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google) + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
+ [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
+ [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
+ [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
+ [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn) + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
+ [x] [360 智脑](https://ai.360.cn) 2. 支持配置镜像以及众多第三方代理服务:
+ [x] [腾讯混元大模型](https://cloud.tencent.com/document/product/1729) + [x] [API Distribute](https://api.gptjk.top/register?aff=QGxj)
+ [x] [Moonshot AI](https://platform.moonshot.cn/) + [x] [OpenAI-SB](https://openai-sb.com)
+ [ ] [字节云雀大模型](https://www.volcengine.com/product/ark) (WIP) + [x] [API2D](https://api2d.com/r/197971)
+ [ ] [MINIMAX](https://api.minimax.chat/) (WIP) + [x] [OhMyGPT](https://aigptx.top?aff=uFpUl2Kf)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 + [x] [AI Proxy](https://aiproxy.io/?i=OneAPI) (邀请码:`OneAPI`
+ [x] [CloseAI](https://console.closeai-asia.com/r/2412)
+ [x] 自定义渠道:例如各种未收录的第三方代理服务
3. 支持通过**负载均衡**的方式访问多个渠道。 3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。 5. 支持**多机部署**[详见此处](#多机部署)。
@@ -89,43 +84,33 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
12. 支持**用户邀请奖励**。 12. 支持**用户邀请奖励**。
13. 支持以美元为单位显示额度。 13. 支持以美元为单位显示额度。
14. 支持发布公告,设置充值链接,设置新用户初始额度。 14. 支持发布公告,设置充值链接,设置新用户初始额度。
15. 支持模型映射,重定向用户的请求模型,如无必要请不要设置,设置之后会导致请求体被重新构造而非直接透传,会导致部分还未正式支持的字段无法传递成功 15. 支持模型映射,重定向用户的请求模型。
16. 支持失败自动重试。 16. 支持失败自动重试。
17. 支持绘图接口。 17. 支持绘图接口。
18. 支持 [Cloudflare AI Gateway](https://developers.cloudflare.com/ai-gateway/providers/openai/),渠道设置的代理部分填写 `https://gateway.ai.cloudflare.com/v1/ACCOUNT_TAG/GATEWAY/openai` 即可。 18. 支持丰富的**自定义**设置,
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。 1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。 2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法) 19. 支持通过系统访问令牌访问管理 API。
21. 支持 Cloudflare Turnstile 用户校验。 20. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式** 21. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。 + 邮箱登录注册以及通过邮箱进行密码重置。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。 + [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。 + 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
## 部署 ## 部署
### 基于 Docker 进行部署 ### 基于 Docker 进行部署
```shell 部署命令:`docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api`
# 使用 SQLite 的部署命令:
docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
# 使用 MySQL 的部署命令,在上面的基础上添加 `-e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi"`,请自行修改数据库连接参数,不清楚如何修改请参见下面环境变量一节。
# 例如:
docker run --name one-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(localhost:3306)/oneapi" -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api
```
其中,`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据和日志将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。 如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。 如果你的并发量较大,推荐设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR` 更新命令:`docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR`
`-p 3000:3000` 中的第一个 `3000` 是宿主机的端口,可以根据需要进行修改。
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
Nginx 的参考配置: Nginx 的参考配置:
``` ```
server{ server{
@@ -158,31 +143,18 @@ sudo service nginx restart
初始账号用户名为 `root`,密码为 `123456` 初始账号用户名为 `root`,密码为 `123456`
### 基于 Docker Compose 进行部署
> 仅启动方式不同,参数设置不变,请参考基于 Docker 部署部分
```shell
# 目前支持 MySQL 启动,数据存储在 ./data/mysql 文件夹内
docker-compose up -d
# 查看部署状态
docker-compose ps
```
### 手动部署 ### 手动部署
1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译: 1. 从 [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) 下载可执行文件或者从源码编译:
```shell ```shell
git clone https://github.com/songquanpeng/one-api.git git clone https://github.com/songquanpeng/one-api.git
# 构建前端 # 构建前端
cd one-api/web/default cd one-api/web
npm install npm install
npm run build npm run build
# 构建后端 # 构建后端
cd ../.. cd ..
go mod download go mod download
go build -ldflags "-s -w" -o one-api go build -ldflags "-s -w" -o one-api
```` ````
@@ -233,23 +205,14 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。 注意修改端口号、`OPENAI_API_BASE_URL` 和 `OPENAI_API_KEY`。
#### QChatGPT - QQ机器人
项目主页https://github.com/RockChinQ/QChatGPT
根据文档完成部署后,在`config.py`设置配置项`openai_config`的`reverse_proxy`为 One API 后端地址,设置`api_key`为 One API 生成的key并在配置项`completion_api_params`的`model`参数设置为 One API 支持的模型名称。
可安装 [Switcher 插件](https://github.com/RockChinQ/Switcher)在运行时切换所使用的模型。
### 部署到第三方平台 ### 部署到第三方平台
<details> <details>
<summary><strong>部署到 Sealos </strong></summary> <summary><strong>部署到 Sealos </strong></summary>
<div> <div>
> Sealos 的服务器在国外,不需要额外处理网络问题,支持高并发 & 动态伸缩 > Sealos 可视化部署,仅需 1 分钟
点击以下按钮一键部署(部署后访问出现 404 请等待 3~5 分钟): 参考这个[教程](https://github.com/c121914yu/FastGPT/blob/main/docs/deploy/one-api/sealos.md)中 1~5 步。
[![Deploy-on-Sealos.svg](https://raw.githubusercontent.com/labring-actions/templates/main/Deploy-on-Sealos.svg)](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
</div> </div>
</details> </details>
@@ -258,12 +221,10 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
<summary><strong>部署到 Zeabur</strong></summary> <summary><strong>部署到 Zeabur</strong></summary>
<div> <div>
> Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用 > Zeabur 的服务器在国外,自动解决了网络的问题,同时免费的额度也足够个人使用
[![Deploy on Zeabur](https://zeabur.com/button.svg)](https://zeabur.com/templates/7Q0KO3)
1. 首先 fork 一份代码。 1. 首先 fork 一份代码。
2. 进入 [Zeabur](https://zeabur.com?referralCode=songquanpeng),登录,进入控制台。 2. 进入 [Zeabur](https://zeabur.com/),登录,进入控制台。
3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口 3. 新建一个 Project在 Service -> Add Service 选择 Marketplace选择 MySQL并记下连接参数用户名、密码、地址、端口
4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。 4. 复制链接参数,运行 ```create database `one-api` ``` 创建数据库。
5. 然后在 Service -> Add Service选择 Git第一次使用需要先授权选择你 fork 的仓库。 5. 然后在 Service -> Add Service选择 Git第一次使用需要先授权选择你 fork 的仓库。
@@ -275,17 +236,6 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope
</div> </div>
</details> </details>
<details>
<summary><strong>部署到 Render</strong></summary>
<div>
> Render 提供免费额度,绑卡后可以进一步提升额度
Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashboard.render.com
</div>
</details>
## 配置 ## 配置
系统本身开箱即用。 系统本身开箱即用。
@@ -304,20 +254,13 @@ Render 可以直接部署 docker 镜像,不需要 fork 仓库https://dashbo
注意,具体的 API Base 的格式取决于你所使用的客户端。 注意,具体的 API Base 的格式取决于你所使用的客户端。
例如对于 OpenAI 的官方库:
```bash
OPENAI_API_KEY="sk-xxxxxx"
OPENAI_API_BASE="https://<HOST>:<PORT>/v1"
```
```mermaid ```mermaid
graph LR graph LR
A(用户) A(用户)
A --->|使用 One API 分发的 key 进行请求| B(One API) A --->|请求| B(One API)
B -->|中继请求| C(OpenAI) B -->|中继请求| C(OpenAI)
B -->|中继请求| D(Azure) B -->|中继请求| D(Azure)
B -->|中继请求| E(其他 OpenAI API 格式下游渠道) B -->|中继请求| E(其他下游渠道)
B -->|中继并修改请求体和返回体| F(非 OpenAI API 格式下游渠道)
``` ```
可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。 可以通过在令牌后面添加渠道 ID 的方式指定使用哪一个渠道处理本次请求,例如:`Authorization: Bearer ONE_API_KEY-CHANNEL_ID`。
@@ -326,57 +269,32 @@ graph LR
不加的话将会使用负载均衡的方式使用多个渠道。 不加的话将会使用负载均衡的方式使用多个渠道。
### 环境变量 ### 环境变量
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用 1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储
+ 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
+ 如果数据库访问延迟很低,没有必要启用 Redis启用后反而会出现数据滞后的问题。
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。 2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
+ 例子:`SESSION_SECRET=random_string` + 例子:`SESSION_SECRET=random_string`
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 或 PostgreSQL 3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite请使用 MySQL 8.0 版本
+ 例子: + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ MySQL`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
+ PostgreSQL`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈)
+ 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。 + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。
+ 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。 + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
+ 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。 + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。
+ 请根据你的数据库配置修改下列参数(或者保持默认值):
+ `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。
+ `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
+ 如果报错 `Error 1040: Too many connections`,请适当减小该值。
+ `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。 4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
+ 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn` + 例子:`FRONTEND_BASE_URL=https://openai.justsong.cn`
5. `MEMORY_CACHE_ENABLED`:启用内存缓存,会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false` 5. `SYNC_FREQUENCY`:设置之后将定期与数据库同步配置,单位为秒,未设置则不进行同步
+ 例子:`MEMORY_CACHE_ENABLED=true`
6. `SYNC_FREQUENCY`:在启用缓存的情况下与数据库同步配置的频率,单位为秒,默认为 `600` 秒。
+ 例子:`SYNC_FREQUENCY=60` + 例子:`SYNC_FREQUENCY=60`
7. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。 6. `NODE_TYPE`:设置之后将指定节点类型,可选值为 `master` 和 `slave`,未设置则默认为 `master`。
+ 例子:`NODE_TYPE=slave` + 例子:`NODE_TYPE=slave`
8. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 7. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
+ 例子:`CHANNEL_UPDATE_FREQUENCY=1440` + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
9. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 8. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
+ 例子:`CHANNEL_TEST_FREQUENCY=1440` + 例子:`CHANNEL_TEST_FREQUENCY=1440`
10. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+ 例子:`POLLING_INTERVAL=5` + 例子:`POLLING_INTERVAL=5`
11. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+ 例子:`BATCH_UPDATE_ENABLED=true`
+ 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
12. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+ 例子:`BATCH_UPDATE_INTERVAL=5`
13. 请求频率限制:
+ `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
+ `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
14. 编码器缓存设置:
+ `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
+ `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
15. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
16. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
17. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
18. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)。
### 命令行参数 ### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。 1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
+ 例子:`--port 3000` + 例子:`--port 3000`
2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,默认保存至工作目录的 `logs` 文件夹下 2. `--log-dir <log_dir>`: 指定日志文件夹,如果没有设置,日志将不会被保存
+ 例子:`--log-dir ./logs` + 例子:`--log-dir ./logs`
3. `--version`: 打印系统版本号并退出。 3. `--version`: 打印系统版本号并退出。
4. `--help`: 查看命令的使用帮助和参数说明。 4. `--help`: 查看命令的使用帮助和参数说明。
@@ -395,7 +313,6 @@ https://openai.justsong.cn
+ 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率) + 额度 = 分组倍率 * 模型倍率 * (提示 token 数 + 补全 token 数 * 补全倍率)
+ 其中补全倍率对于 GPT3.5 固定为 1.33GPT4 为 2与官方保持一致。 + 其中补全倍率对于 GPT3.5 固定为 1.33GPT4 为 2与官方保持一致。
+ 如果是非流模式,官方接口会返回消耗的总 token但是你要注意提示和补全的消耗倍率不一样。 + 如果是非流模式,官方接口会返回消耗的总 token但是你要注意提示和补全的消耗倍率不一样。
+ 注意One API 的默认倍率就是官方倍率,是已经调整过的。
2. 账户额度足够为什么提示额度不足? 2. 账户额度足够为什么提示额度不足?
+ 请检查你的令牌额度是否足够,这个和账户额度是分开的。 + 请检查你的令牌额度是否足够,这个和账户额度是分开的。
+ 令牌额度仅供用户设置最大使用量,用户可自由设置。 + 令牌额度仅供用户设置最大使用量,用户可自由设置。
@@ -408,22 +325,11 @@ https://openai.justsong.cn
5. ChatGPT Next Web 报错:`Failed to fetch` 5. ChatGPT Next Web 报错:`Failed to fetch`
+ 部署的时候不要设置 `BASE_URL`。 + 部署的时候不要设置 `BASE_URL`。
+ 检查你的接口地址和 API Key 有没有填对。 + 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试` 6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游通道 429 了。 + 上游通道 429 了。
7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
8. 升级之前数据库需要做变更吗?
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的通道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的通道。
+ 对于每一个通道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该通道支持该模型。
## 相关项目 ## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统 [FastGPT](https://github.com/c121914yu/FastGPT): 三分钟搭建 AI 知识库
* [ChatGPT Next Web](https://github.com/Yidadaa/ChatGPT-Next-Web): 一键拥有你自己的跨平台 ChatGPT 应用
## 注意 ## 注意
@@ -431,4 +337,4 @@ https://openai.justsong.cn
同样适用于基于本项目的二开项目。 同样适用于基于本项目的二开项目。
依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。 依据 MIT 协议,使用者需自行承担使用本项目的风险与责任,本开源项目开发者与此无关。

View File

@@ -1,127 +0,0 @@
package config
import (
"github.com/songquanpeng/one-api/common/helper"
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var WeChatAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var EmailDomainRestrictionEnabled = false
var EmailDomainWhitelist = []string{
"gmail.com",
"163.com",
"126.com",
"qq.com",
"outlook.com",
"hotmail.com",
"icloud.com",
"yahoo.com",
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
}
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute

View File

@@ -1,9 +1,92 @@
package common package common
import "time" import (
"os"
"strconv"
"sync"
"time"
"github.com/google/uuid"
)
var StartTime = time.Now().Unix() // unit: second var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
var SystemName = "One API"
var ServerAddress = "http://localhost:3000"
var Footer = ""
var Logo = ""
var TopUpLink = ""
var ChatLink = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true
var UsingSQLite = false
var UsingPostgreSQL = false
// Any options with "Secret", "Token" in its key won't be return by GetOptions
var SessionSecret = uuid.New().String()
var SQLitePath = "one-api.db"
var OptionMap map[string]string
var OptionMapRWMutex sync.RWMutex
var ItemsPerPage = 10
var MaxRecentItems = 100
var PasswordLoginEnabled = true
var PasswordRegisterEnabled = true
var EmailVerificationEnabled = false
var GitHubOAuthEnabled = false
var DiscordOAuthEnabled = false
var WeChatAuthEnabled = false
var GoogleOAuthEnabled = false
var TurnstileCheckEnabled = false
var RegisterEnabled = true
var LogConsumeEnabled = true
var SMTPServer = ""
var SMTPPort = 587
var SMTPAccount = ""
var SMTPFrom = ""
var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var DiscordClientId = ""
var DiscordClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
var GoogleClientId = ""
var GoogleClientSecret = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
const ( const (
RoleGuestUser = 0 RoleGuestUser = 0
@@ -12,6 +95,34 @@ const (
RoleRootUser = 100 RoleRootUser = 100
) )
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = 180
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = 60
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const ( const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0 UserStatusDisabled = 2 // also don't use 0
@@ -31,39 +142,39 @@ const (
) )
const ( const (
ChannelStatusUnknown = 0 ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0 ChannelStatusDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
) )
const ( const (
ChannelTypeUnknown = 0 ChannelAllowNonStreamEnabled = 1
ChannelTypeOpenAI = 1 ChannelAllowNonStreamDisabled = 2
ChannelTypeAPI2D = 2 )
ChannelTypeAzure = 3
ChannelTypeCloseAI = 4 const (
ChannelTypeOpenAISB = 5 ChannelAllowStreamEnabled = 1
ChannelTypeOpenAIMax = 6 ChannelAllowStreamDisabled = 2
ChannelTypeOhMyGPT = 7 )
ChannelTypeCustom = 8
ChannelTypeAILS = 9 const (
ChannelTypeAIProxy = 10 ChannelTypeUnknown = 0
ChannelTypePaLM = 11 ChannelTypeOpenAI = 1
ChannelTypeAPI2GPT = 12 ChannelTypeAPI2D = 2
ChannelTypeAIGC2D = 13 ChannelTypeAzure = 3
ChannelTypeAnthropic = 14 ChannelTypeCloseAI = 4
ChannelTypeBaidu = 15 ChannelTypeOpenAISB = 5
ChannelTypeZhipu = 16 ChannelTypeOpenAIMax = 6
ChannelTypeAli = 17 ChannelTypeOhMyGPT = 7
ChannelTypeXunfei = 18 ChannelTypeCustom = 8
ChannelType360 = 19 ChannelTypeAILS = 9
ChannelTypeOpenRouter = 20 ChannelTypeAIProxy = 10
ChannelTypeAIProxyLibrary = 21 ChannelTypePaLM = 11
ChannelTypeFastGPT = 22 ChannelTypeAPI2GPT = 12
ChannelTypeTencent = 23 ChannelTypeAIGC2D = 13
ChannelTypeGemini = 24 ChannelTypeAnthropic = 14
ChannelTypeMoonshot = 25 ChannelTypeBaidu = 15
ChannelTypeZhipu = 16
) )
var ChannelBaseURLs = []string{ var ChannelBaseURLs = []string{
@@ -78,27 +189,10 @@ var ChannelBaseURLs = []string{
"", // 8 "", // 8
"https://api.caipacity.com", // 9 "https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10 "https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11 "", // 11
"https://api.api2gpt.com", // 12 "https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13 "https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14 "https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15 "https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16 "https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
} }
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@@ -1,9 +0,0 @@
package common
import "github.com/songquanpeng/one-api/common/helper"
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,57 +1,37 @@
package common package common
import ( import (
"crypto/rand"
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config"
"net/smtp" "net/smtp"
"strings" "strings"
"time"
) )
func SendEmail(subject string, receiver string, content string) error { func SendEmail(subject string, receiver string, content string) error {
if config.SMTPFrom == "" { // for compatibility if SMTPFrom == "" { // for compatibility
config.SMTPFrom = config.SMTPAccount SMTPFrom = SMTPAccount
} }
encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject)))
// Extract domain from SMTPFrom
parts := strings.Split(config.SMTPFrom, "@")
var domain string
if len(parts) > 1 {
domain = parts[1]
}
// Generate a unique Message-ID
buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
return err
}
messageId := fmt.Sprintf("<%x@%s>", buf, domain)
mail := []byte(fmt.Sprintf("To: %s\r\n"+ mail := []byte(fmt.Sprintf("To: %s\r\n"+
"From: %s<%s>\r\n"+ "From: %s<%s>\r\n"+
"Subject: %s\r\n"+ "Subject: %s\r\n"+
"Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322
"Date: %s\r\n"+
"Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n",
receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) receiver, SystemName, SMTPFrom, encodedSubject, content))
auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer)
addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort)
to := strings.Split(receiver, ";") to := strings.Split(receiver, ";")
var err error
if config.SMTPPort == 465 { if SMTPPort == 465 {
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: config.SMTPServer, ServerName: SMTPServer,
} }
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig)
if err != nil { if err != nil {
return err return err
} }
client, err := smtp.NewClient(conn, config.SMTPServer) client, err := smtp.NewClient(conn, SMTPServer)
if err != nil { if err != nil {
return err return err
} }
@@ -59,7 +39,7 @@ func SendEmail(subject string, receiver string, content string) error {
if err = client.Auth(auth); err != nil { if err = client.Auth(auth); err != nil {
return err return err
} }
if err = client.Mail(config.SMTPFrom); err != nil { if err = client.Mail(SMTPFrom); err != nil {
return err return err
} }
receiverEmails := strings.Split(receiver, ";") receiverEmails := strings.Split(receiver, ";")
@@ -81,7 +61,7 @@ func SendEmail(subject string, receiver string, content string) error {
return err return err
} }
} else { } else {
err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) err = smtp.SendMail(addr, auth, SMTPAccount, to, mail)
} }
return err return err
} }

View File

@@ -15,7 +15,10 @@ type embedFileSystem struct {
func (e embedFileSystem) Exists(prefix string, path string) bool { func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path) _, err := e.Open(path)
return err == nil if err != nil {
return false
}
return true
} }
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem { func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {

View File

@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"strings"
) )
func UnmarshalBodyReusable(c *gin.Context, v any) error { func UnmarshalBodyReusable(c *gin.Context, v any) error {
@@ -17,13 +16,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
if err != nil { if err != nil {
return err return err
} }
contentType := c.Request.Header.Get("Content-Type") err = json.Unmarshal(requestBody, &v)
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
} else {
// skip for now
// TODO: someday non json request have variant model, we will need to implementation this
}
if err != nil { if err != nil {
return err return err
} }
@@ -31,11 +24,3 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil return nil
} }
func SetEventStreamHeaders(c *gin.Context) {
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
}

View File

@@ -1,9 +1,6 @@
package common package common
import ( import "encoding/json"
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
)
var GroupRatio = map[string]float64{ var GroupRatio = map[string]float64{
"default": 1, "default": 1,
@@ -14,7 +11,7 @@ var GroupRatio = map[string]float64{
func GroupRatio2JSONString() string { func GroupRatio2JSONString() string {
jsonBytes, err := json.Marshal(GroupRatio) jsonBytes, err := json.Marshal(GroupRatio)
if err != nil { if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error()) SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }
@@ -27,7 +24,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error {
func GetGroupRatio(name string) float64 { func GetGroupRatio(name string) float64 {
ratio, ok := GroupRatio[name] ratio, ok := GroupRatio[name]
if !ok { if !ok {
logger.SysError("group ratio not found: " + name) SysError("group ratio not found: " + name)
return 1 return 1
} }
return ratio return ratio

View File

@@ -1,224 +0,0 @@
package helper
import (
"fmt"
"github.com/google/uuid"
"github.com/songquanpeng/one-api/common/logger"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
func OpenBrowser(url string) {
var err error
switch runtime.GOOS {
case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
return inter
case int:
return fmt.Sprintf("%d", inter)
case float64:
return fmt.Sprintf("%f", inter)
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value
}
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
return fmt.Sprintf("%s (request id: %s)", message, id)
}
func String2Int(str string) int {
num, err := strconv.Atoi(str)
if err != nil {
return 0
}
return num
}

View File

@@ -1,111 +0,0 @@
package image
import (
"bytes"
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"net/http"
"regexp"
"strings"
"sync"
_ "golang.org/x/image/webp"
)
// Regex to match data URL pattern
var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`)
func IsImageUrl(url string) (bool, error) {
resp, err := http.Head(url)
if err != nil {
return false, err
}
if !strings.HasPrefix(resp.Header.Get("Content-Type"), "image/") {
return false, nil
}
return true, nil
}
func GetImageSizeFromUrl(url string) (width int, height int, err error) {
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
img, _, err := image.DecodeConfig(resp.Body)
if err != nil {
return
}
return img.Width, img.Height, nil
}
func GetImageFromUrl(url string) (mimeType string, data string, err error) {
// Check if the URL is a data URL
matches := dataURLPattern.FindStringSubmatch(url)
if len(matches) == 3 {
// URL is a data URL
mimeType = "image/" + matches[1]
data = matches[2]
return
}
isImage, err := IsImageUrl(url)
if !isImage {
return
}
resp, err := http.Get(url)
if err != nil {
return
}
defer resp.Body.Close()
buffer := bytes.NewBuffer(nil)
_, err = buffer.ReadFrom(resp.Body)
if err != nil {
return
}
mimeType = resp.Header.Get("Content-Type")
data = base64.StdEncoding.EncodeToString(buffer.Bytes())
return
}
var (
reg = regexp.MustCompile(`data:image/([^;]+);base64,`)
)
var readerPool = sync.Pool{
New: func() interface{} {
return &bytes.Reader{}
},
}
func GetImageSizeFromBase64(encoded string) (width int, height int, err error) {
decoded, err := base64.StdEncoding.DecodeString(reg.ReplaceAllString(encoded, ""))
if err != nil {
return 0, 0, err
}
reader := readerPool.Get().(*bytes.Reader)
defer readerPool.Put(reader)
reader.Reset(decoded)
img, _, err := image.DecodeConfig(reader)
if err != nil {
return 0, 0, err
}
return img.Width, img.Height, nil
}
func GetImageSize(image string) (width int, height int, err error) {
if strings.HasPrefix(image, "data:image/") {
return GetImageSizeFromBase64(image)
}
return GetImageSizeFromUrl(image)
}

View File

@@ -1,171 +0,0 @@
package image_test
import (
"encoding/base64"
"image"
_ "image/gif"
_ "image/jpeg"
_ "image/png"
"io"
"net/http"
"strconv"
"strings"
"testing"
img "github.com/songquanpeng/one-api/common/image"
"github.com/stretchr/testify/assert"
_ "golang.org/x/image/webp"
)
type CountingReader struct {
reader io.Reader
BytesRead int
}
func (r *CountingReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.BytesRead += n
return n, err
}
var (
cases = []struct {
url string
format string
width int
height int
}{
{"https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", "jpeg", 2560, 1669},
{"https://upload.wikimedia.org/wikipedia/commons/9/97/Basshunter_live_performances.png", "png", 4500, 2592},
{"https://upload.wikimedia.org/wikipedia/commons/c/c6/TO_THE_ONE_SOMETHINGNESS.webp", "webp", 984, 985},
{"https://upload.wikimedia.org/wikipedia/commons/d/d0/01_Das_Sandberg-Modell.gif", "gif", 1917, 1533},
{"https://upload.wikimedia.org/wikipedia/commons/6/62/102Cervus.jpg", "jpeg", 270, 230},
}
)
func TestDecode(t *testing.T) {
// Bytes read: varies sometimes
// jpeg: 1063892
// png: 294462
// webp: 99529
// gif: 956153
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 4096
// png: 4096
// webp: 4096
// gif: 4096
// jpeg#01: 4096
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
reader := &CountingReader{reader: resp.Body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestBase64(t *testing.T) {
// Bytes read:
// jpeg: 1063892
// png: 294462
// webp: 99072
// gif: 953856
// jpeg#01: 32805
for _, c := range cases {
t.Run("Decode:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
img, format, err := image.Decode(reader)
assert.NoError(t, err)
size := img.Bounds().Size()
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, size.X)
assert.Equal(t, c.height, size.Y)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
// Bytes read:
// jpeg: 1536
// png: 768
// webp: 768
// gif: 1536
// jpeg#01: 3840
for _, c := range cases {
t.Run("DecodeConfig:"+c.format, func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
body := base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded))
reader := &CountingReader{reader: body}
config, format, err := image.DecodeConfig(reader)
assert.NoError(t, err)
assert.Equal(t, c.format, format)
assert.Equal(t, c.width, config.Width)
assert.Equal(t, c.height, config.Height)
t.Logf("Bytes read: %d", reader.BytesRead)
})
}
}
func TestGetImageSize(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
width, height, err := img.GetImageSize(c.url)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}
func TestGetImageSizeFromBase64(t *testing.T) {
for i, c := range cases {
t.Run("Decode:"+strconv.Itoa(i), func(t *testing.T) {
resp, err := http.Get(c.url)
assert.NoError(t, err)
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
encoded := base64.StdEncoding.EncodeToString(data)
width, height, err := img.GetImageSizeFromBase64(encoded)
assert.NoError(t, err)
assert.Equal(t, c.width, width)
assert.Equal(t, c.height, height)
})
}
}

View File

@@ -3,8 +3,6 @@ package common
import ( import (
"flag" "flag"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,7 +12,7 @@ var (
Port = flag.Int("port", 3000, "the listening port") Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit") PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit") PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "./logs", "specify the log directory") LogDir = flag.String("log-dir", "", "specify the log directory")
) )
func printHelp() { func printHelp() {
@@ -38,11 +36,7 @@ func init() {
} }
if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" { SessionSecret = os.Getenv("SESSION_SECRET")
logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
config.SessionSecret = os.Getenv("SESSION_SECRET")
}
} }
if os.Getenv("SQLITE_PATH") != "" { if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH") SQLitePath = os.Getenv("SQLITE_PATH")
@@ -59,6 +53,5 @@ func init() {
log.Fatal(err) log.Fatal(err)
} }
} }
logger.LogDir = *LogDir
} }
} }

52
common/logger.go Normal file
View File

@@ -0,0 +1,52 @@
package common
import (
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"time"
)
func SetupGinLog() {
if *LogDir != "" {
commonLogPath := filepath.Join(*LogDir, "common.log")
errorLogPath := filepath.Join(*LogDir, "error.log")
commonFd, err := os.OpenFile(commonLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
errorFd, err := os.OpenFile(errorLogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, commonFd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, errorFd)
}
}
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}
func LogQuota(quota int) string {
if DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/QuotaPerUnit)
} else {
return fmt.Sprintf("%d 点额度", quota)
}
}

View File

@@ -1,7 +0,0 @@
package logger
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var LogDir string

View File

@@ -1,104 +0,0 @@
package logger
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
)
const (
loggerINFO = "INFO"
loggerWarn = "WARN"
loggerError = "ERR"
)
const maxLogCount = 1000000
var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
func SetupLogger() {
if LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
}
func SysLog(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func SysError(s string) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s)
}
func Info(ctx context.Context, msg string) {
logHelper(ctx, loggerINFO, msg)
}
func Warn(ctx context.Context, msg string) {
logHelper(ctx, loggerWarn, msg)
}
func Error(ctx context.Context, msg string) {
logHelper(ctx, loggerError, msg)
}
func Infof(ctx context.Context, format string, a ...any) {
Info(ctx, fmt.Sprintf(format, a...))
}
func Warnf(ctx context.Context, format string, a ...any) {
Warn(ctx, fmt.Sprintf(format, a...))
}
func Errorf(ctx context.Context, format string, a ...any) {
Error(ctx, fmt.Sprintf(format, a...))
}
func logHelper(ctx context.Context, level string, msg string) {
writer := gin.DefaultErrorWriter
if level == loggerINFO {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
logCount++ // we don't need accurate count, so no lock here
if logCount > maxLogCount && !setupLogWorking {
logCount = 0
setupLogWorking = true
go func() {
SetupLogger()
}()
}
}
func FatalLog(v ...any) {
t := time.Now()
_, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v)
os.Exit(1)
}

View File

@@ -1,40 +1,6 @@
package common package common
import ( import "encoding/json"
"encoding/json"
"github.com/songquanpeng/one-api/common/logger"
"strings"
"time"
)
var DalleSizeRatios = map[string]map[string]float64{
"dall-e-2": {
"256x256": 1,
"512x512": 1.125,
"1024x1024": 1.25,
},
"dall-e-3": {
"1024x1024": 1,
"1024x1792": 2,
"1792x1024": 2,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
}
const (
USD2RMB = 7
USD = 500 // $0.002 = 1 -> $1 = 500
RMB = USD / USD2RMB
)
// ModelRatio // ModelRatio
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
@@ -44,27 +10,17 @@ const (
// 1 === $0.002 / 1K tokens // 1 === $0.002 / 1K tokens
// 1 === ¥0.014 / 1k tokens // 1 === ¥0.014 / 1k tokens
var ModelRatio = map[string]float64{ var ModelRatio = map[string]float64{
// https://openai.com/pricing
"gpt-4": 15, "gpt-4": 15,
"gpt-4-0314": 15, "gpt-4-0314": 15,
"gpt-4-0613": 15, "gpt-4-0613": 15,
"gpt-4-32k": 30, "gpt-4-32k": 30,
"gpt-4-32k-0314": 30, "gpt-4-32k-0314": 30,
"gpt-4-32k-0613": 30, "gpt-4-32k-0613": 30,
"gpt-4-1106-preview": 5, // $0.01 / 1K tokens
"gpt-4-0125-preview": 5, // $0.01 / 1K tokens
"gpt-4-turbo-preview": 5, // $0.01 / 1K tokens
"gpt-4-vision-preview": 5, // $0.01 / 1K tokens
"gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-0301": 0.75, "gpt-3.5-turbo-0301": 0.75,
"gpt-3.5-turbo-0613": 0.75, "gpt-3.5-turbo-0613": 0.75,
"gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens "gpt-3.5-turbo-16k": 1.5, // $0.003 / 1K tokens
"gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-16k-0613": 1.5,
"gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens
"gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens
"gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens
"davinci-002": 1, // $0.002 / 1K tokens
"babbage-002": 0.2, // $0.0004 / 1K tokens
"text-ada-001": 0.2, "text-ada-001": 0.2,
"text-babbage-001": 0.25, "text-babbage-001": 0.25,
"text-curie-001": 1, "text-curie-001": 1,
@@ -72,63 +28,30 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10, "text-davinci-003": 10,
"text-davinci-edit-001": 10, "text-davinci-edit-001": 10,
"code-davinci-edit-001": 10, "code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens "whisper-1": 10,
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
"tts-1-hd-1106": 15,
"davinci": 10, "davinci": 10,
"curie": 10, "curie": 10,
"babbage": 10, "babbage": 10,
"ada": 10, "ada": 10,
"text-embedding-ada-002": 0.05, "text-embedding-ada-002": 0.05,
"text-embedding-3-small": 0.01,
"text-embedding-3-large": 0.065,
"text-search-ada-doc-001": 10, "text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1, "text-moderation-stable": 0.1,
"text-moderation-latest": 0.1, "text-moderation-latest": 0.1,
"dall-e-2": 8, // $0.016 - $0.020 / image "dall-e": 8,
"dall-e-3": 20, // $0.040 - $0.120 / image "claude-instant-1": 0.75,
"claude-instant-1": 0.815, // $1.63 / 1M tokens "claude-2": 30,
"claude-2": 5.51, // $11.02 / 1M tokens "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"claude-2.0": 5.51, // $11.02 / 1M tokens "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"claude-2.1": 5.51, // $11.02 / 1M tokens "PaLM-2": 1,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens
"embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens
"embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
"hunyuan": 7.143, // ¥0.1 / 1k tokens // https://cloud.tencent.com/document/product/1729/97731#e0e6be58-60c8-469f-bdeb-6c264ce3b4d0
"ChatStd": 0.01 * RMB,
"ChatPro": 0.1 * RMB,
// https://platform.moonshot.cn/pricing
"moonshot-v1-8k": 0.012 * RMB,
"moonshot-v1-32k": 0.024 * RMB,
"moonshot-v1-128k": 0.06 * RMB,
} }
func ModelRatio2JSONString() string { func ModelRatio2JSONString() string {
jsonBytes, err := json.Marshal(ModelRatio) jsonBytes, err := json.Marshal(ModelRatio)
if err != nil { if err != nil {
logger.SysError("error marshalling model ratio: " + err.Error()) SysError("error marshalling model ratio: " + err.Error())
} }
return string(jsonBytes) return string(jsonBytes)
} }
@@ -139,67 +62,10 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
} }
func GetModelRatio(name string) float64 { func GetModelRatio(name string) float64 {
if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") {
name = strings.TrimSuffix(name, "-internet")
}
ratio, ok := ModelRatio[name] ratio, ok := ModelRatio[name]
if !ok { if !ok {
logger.SysError("model ratio not found: " + name) SysError("model ratio not found: " + name)
return 30 return 30
} }
return ratio return ratio
} }
var CompletionRatio = map[string]float64{}
func CompletionRatio2JSONString() string {
jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil {
logger.SysError("error marshalling completion ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateCompletionRatioByJSONString(jsonStr string) error {
CompletionRatio = make(map[string]float64)
return json.Unmarshal([]byte(jsonStr), &CompletionRatio)
}
func GetCompletionRatio(name string) float64 {
if ratio, ok := CompletionRatio[name]; ok {
return ratio
}
if strings.HasPrefix(name, "gpt-3.5") {
if strings.HasSuffix(name, "0125") {
// https://openai.com/blog/new-embedding-models-and-api-updates
// Updated GPT-3.5 Turbo model and lower pricing
return 3
}
if strings.HasSuffix(name, "1106") {
return 2
}
if name == "gpt-3.5-turbo" || name == "gpt-3.5-turbo-16k" {
// TODO: clear this after 2023-12-11
now := time.Now()
// https://platform.openai.com/docs/models/continuous-model-upgrades
// if after 2023-12-11, use 2
if now.After(time.Date(2023, 12, 11, 0, 0, 0, 0, time.UTC)) {
return 2
}
}
return 1.333333
}
if strings.HasPrefix(name, "gpt-4") {
if strings.HasSuffix(name, "preview") {
return 3
}
return 2
}
if strings.HasPrefix(name, "claude-instant-1") {
return 3.38
}
if strings.HasPrefix(name, "claude-2") {
return 2.965517
}
return 1
}

View File

@@ -3,7 +3,6 @@ package common
import ( import (
"context" "context"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/songquanpeng/one-api/common/logger"
"os" "os"
"time" "time"
) )
@@ -15,18 +14,18 @@ var RedisEnabled = true
func InitRedisClient() (err error) { func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" { if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false RedisEnabled = false
logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil return nil
} }
if os.Getenv("SYNC_FREQUENCY") == "" { if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false RedisEnabled = false
logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil return nil
} }
logger.SysLog("Redis is enabled") SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil { if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error()) FatalLog("failed to parse Redis connection string: " + err.Error())
} }
RDB = redis.NewClient(opt) RDB = redis.NewClient(opt)
@@ -35,7 +34,7 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result() _, err = RDB.Ping(ctx).Result()
if err != nil { if err != nil {
logger.FatalLog("Redis ping test failed: " + err.Error()) FatalLog("Redis ping test failed: " + err.Error())
} }
return err return err
} }
@@ -43,7 +42,7 @@ func InitRedisClient() (err error) {
func ParseRedisOption() *redis.Options { func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
if err != nil { if err != nil {
logger.FatalLog("failed to parse Redis connection string: " + err.Error()) FatalLog("failed to parse Redis connection string: " + err.Error())
} }
return opt return opt
} }
@@ -62,8 +61,3 @@ func RedisDel(key string) error {
ctx := context.Background() ctx := context.Background()
return RDB.Del(ctx, key).Err() return RDB.Del(ctx, key).Err()
} }
func RedisDecrease(key string, value int64) error {
ctx := context.Background()
return RDB.DecrBy(ctx, key, value).Err()
}

View File

@@ -2,13 +2,178 @@ package common
import ( import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common/config" "github.com/google/uuid"
"html/template"
"log"
"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
) )
func LogQuota(quota int) string { func OpenBrowser(url string) {
if config.DisplayInCurrencyEnabled { var err error
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else { switch runtime.GOOS {
return fmt.Sprintf("%d 点额度", quota) case "linux":
err = exec.Command("xdg-open", url).Start()
case "windows":
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
case "darwin":
err = exec.Command("open", url).Start()
}
if err != nil {
log.Println(err)
}
}
func GetIp() (ip string) {
ips, err := net.InterfaceAddrs()
if err != nil {
log.Println(err)
return ip
}
for _, a := range ips {
if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ip = ipNet.IP.String()
if strings.HasPrefix(ip, "10") {
return
}
if strings.HasPrefix(ip, "172") {
return
}
if strings.HasPrefix(ip, "192.168") {
return
}
ip = ""
}
}
}
return
}
var sizeKB = 1024
var sizeMB = sizeKB * 1024
var sizeGB = sizeMB * 1024
func Bytes2Size(num int64) string {
numStr := ""
unit := "B"
if num/int64(sizeGB) > 1 {
numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB))
unit = "GB"
} else if num/int64(sizeMB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB)))
unit = "MB"
} else if num/int64(sizeKB) > 1 {
numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB)))
unit = "KB"
} else {
numStr = fmt.Sprintf("%d", num)
}
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter.(type) {
case string:
return inter.(string)
case int:
return fmt.Sprintf("%d", inter.(int))
case float64:
return fmt.Sprintf("%f", inter.(float64))
}
return "Not Implemented"
}
func UnescapeHTML(x string) interface{} {
return template.HTML(x)
}
func IntMax(a int, b int) int {
if a >= b {
return a
} else {
return b
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func Max(a int, b int) int {
if a >= b {
return a
} else {
return b
} }
} }

View File

@@ -2,9 +2,8 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "one-api/common"
"github.com/songquanpeng/one-api/model" "one-api/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
func GetSubscription(c *gin.Context) { func GetSubscription(c *gin.Context) {
@@ -13,7 +12,7 @@ func GetSubscription(c *gin.Context) {
var err error var err error
var token *model.Token var token *model.Token
var expiredTime int64 var expiredTime int64
if config.DisplayTokenStatEnabled { if common.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime expiredTime = token.ExpiredTime
@@ -22,27 +21,25 @@ func GetSubscription(c *gin.Context) {
} else { } else {
userId := c.GetInt("id") userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId) remainQuota, err = model.GetUserQuota(userId)
if err != nil { usedQuota, err = model.GetUserUsedQuota(userId)
usedQuota, err = model.GetUserUsedQuota(userId)
}
} }
if expiredTime <= 0 { if expiredTime <= 0 {
expiredTime = 0 expiredTime = 0
} }
if err != nil { if err != nil {
Error := relaymodel.Error{ openAIError := OpenAIError{
Message: err.Error(), Message: err.Error(),
Type: "upstream_error", Type: "one_api_error",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": Error, "error": openAIError,
}) })
return return
} }
quota := remainQuota + usedQuota quota := remainQuota + usedQuota
amount := float64(quota) amount := float64(quota)
if config.DisplayInCurrencyEnabled { if common.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit amount /= common.QuotaPerUnit
} }
if token != nil && token.UnlimitedQuota { if token != nil && token.UnlimitedQuota {
amount = 100000000 amount = 100000000
@@ -63,7 +60,7 @@ func GetUsage(c *gin.Context) {
var quota int var quota int
var err error var err error
var token *model.Token var token *model.Token
if config.DisplayTokenStatEnabled { if common.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id") tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId) token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota quota = token.UsedQuota
@@ -72,18 +69,18 @@ func GetUsage(c *gin.Context) {
quota, err = model.GetUserUsedQuota(userId) quota, err = model.GetUserUsedQuota(userId)
} }
if err != nil { if err != nil {
Error := relaymodel.Error{ openAIError := OpenAIError{
Message: err.Error(), Message: err.Error(),
Type: "one_api_error", Type: "one_api_error",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": Error, "error": openAIError,
}) })
return return
} }
amount := float64(quota) amount := float64(quota)
if config.DisplayInCurrencyEnabled { if common.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit amount /= common.QuotaPerUnit
} }
usage := OpenAIUsageResponse{ usage := OpenAIUsageResponse{
Object: "list", Object: "list",

View File

@@ -4,13 +4,10 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/util"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
@@ -95,7 +92,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers { for k := range headers {
req.Header.Add(k, headers.Get(k)) req.Header.Add(k, headers.Get(k))
} }
res, err := util.HTTPClient.Do(req) res, err := httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -114,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
} }
func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) {
url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL)
body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
if err != nil { if err != nil {
@@ -204,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
func updateChannelBalance(channel *model.Channel) (float64, error) { func updateChannelBalance(channel *model.Channel) (float64, error) {
baseURL := common.ChannelBaseURLs[channel.Type] baseURL := common.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" { if channel.BaseURL == "" {
channel.BaseURL = &baseURL channel.BaseURL = baseURL
} }
switch channel.Type { switch channel.Type {
case common.ChannelTypeOpenAI: case common.ChannelTypeOpenAI:
if channel.GetBaseURL() != "" { if channel.BaseURL != "" {
baseURL = channel.GetBaseURL() baseURL = channel.BaseURL
} }
case common.ChannelTypeAzure: case common.ChannelTypeAzure:
return 0, errors.New("尚未实现") return 0, errors.New("尚未实现")
case common.ChannelTypeCustom: case common.ChannelTypeCustom:
baseURL = channel.GetBaseURL() baseURL = channel.BaseURL
case common.ChannelTypeCloseAI: case common.ChannelTypeCloseAI:
return updateChannelCloseAIBalance(channel) return updateChannelCloseAIBalance(channel)
case common.ChannelTypeOpenAISB: case common.ChannelTypeOpenAISB:
@@ -316,7 +313,7 @@ func updateAllChannelsBalance() error {
disableChannel(channel.Id, channel.Name, "余额不足") disableChannel(channel.Id, channel.Name, "余额不足")
} }
} }
time.Sleep(config.RequestInterval) time.Sleep(common.RequestInterval)
} }
return nil return nil
} }
@@ -341,8 +338,8 @@ func UpdateAllChannelsBalance(c *gin.Context) {
func AutomaticallyUpdateChannels(frequency int) { func AutomaticallyUpdateChannels(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("updating all channels") common.SysLog("updating all channels")
_ = updateAllChannelsBalance() _ = updateAllChannelsBalance()
logger.SysLog("channels update done") common.SysLog("channels update done")
} }
} }

View File

@@ -1,36 +1,148 @@
package controller package controller
import ( import (
"bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http" "net/http"
"net/http/httptest" "one-api/common"
"net/url" "one-api/model"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func buildTestRequest() *relaymodel.GeneralOpenAIRequest { func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
testRequest := &relaymodel.GeneralOpenAIRequest{ switch channel.Type {
MaxTokens: 1, case common.ChannelTypePaLM:
Stream: false, fallthrough
Model: "gpt-3.5-turbo", case common.ChannelTypeAnthropic:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
case common.ChannelTypeAzure:
request.Model = "gpt-35-turbo"
default:
request.Model = "gpt-3.5-turbo"
} }
testMessage := relaymodel.Message{ requestURL := common.ChannelBaseURLs[channel.Type]
if channel.Type == common.ChannelTypeAzure {
requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
} else {
if channel.BaseURL != "" {
requestURL = channel.BaseURL
}
requestURL += "/v1/chat/completions"
}
jsonData, err := json.Marshal(request)
if err != nil {
return err, nil
}
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
if err != nil {
return err, nil
}
if channel.Type == common.ChannelTypeAzure {
req.Header.Set("api-key", channel.Key)
} else {
req.Header.Set("Authorization", "Bearer "+channel.Key)
}
req.Header.Set("Content-Type", "application/json")
resp, err := httpClient.Do(req)
if err != nil {
return err, nil
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errors.New(fmt.Sprintf("status code %d", resp.StatusCode)), nil
}
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
if channel.AllowStreaming == common.ChannelAllowStreamEnabled && isStream {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
// ChatGPT Next Web
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
}
if !strings.HasPrefix(data, "data:") {
continue
}
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
return err, nil
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
}
}
if responseText == "" {
return errors.New("Empty response"), nil
}
} else if channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled {
var response TextResponse
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
return err, nil
}
if response.Usage.CompletionTokens == 0 {
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
}
}
return nil, nil
}
func buildTestRequest(stream bool) *ChatRequest {
testRequest := &ChatRequest{
Model: "", // this will be set later
MaxTokens: 1,
Stream: stream,
}
testMessage := Message{
Role: "user", Role: "user",
Content: "hi", Content: "hi",
} }
@@ -38,65 +150,6 @@ func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
return testRequest return testRequest
} }
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: "POST",
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
if err != nil {
return err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
}
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
}
if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
}
func TestChannel(c *gin.Context) { func TestChannel(c *gin.Context) {
id, err := strconv.Atoi(c.Param("id")) id, err := strconv.Atoi(c.Param("id"))
if err != nil { if err != nil {
@@ -114,8 +167,9 @@ func TestChannel(c *gin.Context) {
}) })
return return
} }
testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
tik := time.Now() tik := time.Now()
err, _ = testChannel(channel) err, _ = testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
go channel.UpdateResponseTime(milliseconds) go channel.UpdateResponseTime(milliseconds)
@@ -139,35 +193,23 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false var testAllChannelsRunning bool = false
func notifyRootUser(subject string, content string) {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
err := common.SendEmail(subject, config.RootUserEmail, content)
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
// disable & notify // disable & notify
func disableChannel(channelId int, channelName string, reason string) { func disableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) if common.RootUserEmail == "" {
common.RootUserEmail = model.GetRootUserEmail()
}
model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId) subject := fmt.Sprintf("通道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason) content := fmt.Sprintf("通道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content) err := common.SendEmail(subject, common.RootUserEmail, content)
} if err != nil {
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
// enable & notify }
func enableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
subject := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("通道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
} }
func testAllChannels(notify bool) error { func testAllChannels(notify bool) error {
if config.RootUserEmail == "" { if common.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail() common.RootUserEmail = model.GetRootUserEmail()
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
if testAllChannelsRunning { if testAllChannelsRunning {
@@ -180,37 +222,37 @@ func testAllChannels(notify bool) error {
if err != nil { if err != nil {
return err return err
} }
var disableThreshold = int64(config.ChannelDisableThreshold * 1000) var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
if disableThreshold == 0 { if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value disableThreshold = 10000000 // a impossible value
} }
go func() { go func() {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled if channel.Status != common.ChannelStatusEnabled {
continue
}
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) testRequest := buildTestRequest(channel.AllowStreaming == common.ChannelAllowStreamEnabled)
err, openaiErr := testChannel(channel, *testRequest)
tok := time.Now() tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds() milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold { if milliseconds > disableThreshold {
err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
disableChannel(channel.Id, channel.Name, err.Error()) disableChannel(channel.Id, channel.Name, err.Error())
} }
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { if shouldDisableChannel(openaiErr) {
disableChannel(channel.Id, channel.Name, err.Error()) disableChannel(channel.Id, channel.Name, err.Error())
} }
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
enableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)
time.Sleep(config.RequestInterval) time.Sleep(common.RequestInterval)
} }
testAllChannelsLock.Lock() testAllChannelsLock.Lock()
testAllChannelsRunning = false testAllChannelsRunning = false
testAllChannelsLock.Unlock() testAllChannelsLock.Unlock()
if notify { if notify {
err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
if err != nil { if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
} }
} }
}() }()
@@ -236,8 +278,8 @@ func TestAllChannels(c *gin.Context) {
func AutomaticallyTestChannels(frequency int) { func AutomaticallyTestChannels(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Minute) time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("testing all channels") common.SysLog("testing all channels")
_ = testAllChannels(false) _ = testAllChannels(false)
logger.SysLog("channel test finished") common.SysLog("channel test finished")
} }
} }

View File

@@ -1,13 +1,13 @@
package controller package controller
import ( import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func GetAllChannels(c *gin.Context) { func GetAllChannels(c *gin.Context) {
@@ -15,7 +15,7 @@ func GetAllChannels(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -84,9 +84,9 @@ func AddChannel(c *gin.Context) {
}) })
return return
} }
channel.CreatedTime = helper.GetTimestamp() channel.CreatedTime = common.GetTimestamp()
keys := strings.Split(channel.Key, "\n") keys := strings.Split(channel.Key, "\n")
channels := make([]model.Channel, 0, len(keys)) channels := make([]model.Channel, 0)
for _, key := range keys { for _, key := range keys {
if key == "" { if key == "" {
continue continue
@@ -128,23 +128,6 @@ func DeleteChannel(c *gin.Context) {
return return
} }
func DeleteDisabledChannel(c *gin.Context) {
rows, err := model.DeleteDisabledChannel()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": rows,
})
return
}
func UpdateChannel(c *gin.Context) { func UpdateChannel(c *gin.Context) {
channel := model.Channel{} channel := model.Channel{}
err := c.ShouldBindJSON(&channel) err := c.ShouldBindJSON(&channel)

223
controller/discord.go Normal file
View File

@@ -0,0 +1,223 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type DiscordOAuthResponse struct {
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
TokenType string `json:"token_type"`
}
type DiscordUser struct {
Id string `json:"id"`
Username string `json:"username"`
}
func getDiscordUserInfoByCode(codeFromURLParamaters string, host string) (*DiscordUser, error) {
if codeFromURLParamaters == "" {
return nil, errors.New("无效参数")
}
RequestClient := &http.Client{}
accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf(
"client_id=%s&client_secret=%s&grant_type=authorization_code&redirect_uri=https://%s/oauth/discord&code=%s&scope=identify",
common.DiscordClientId, common.DiscordClientSecret, common.ServerAddress, codeFromURLParamaters,
)))
req, _ := http.NewRequest("POST",
"https://discordapp.com/api/oauth2/token",
accessTokenBody,
)
req.Header = http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded"},
"Accept": []string{"application/json"},
}
resp, err := RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("访问令牌无效")
}
var discordOAuthResponse DiscordOAuthResponse
json.NewDecoder(resp.Body).Decode(&discordOAuthResponse)
accessToken := fmt.Sprintf("Bearer %s", discordOAuthResponse.AccessToken)
// Get User Info
req, _ = http.NewRequest("GET", "https://discord.com/api/users/@me", nil)
req.Header = http.Header{
"Content-Type": []string{"application/json"},
"Authorization": []string{accessToken},
}
defer resp.Body.Close()
resp, err = RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("Discord 用户信息无效")
}
var discordUser DiscordUser
json.NewDecoder(resp.Body).Decode(&discordUser)
if err != nil {
return nil, err
}
if discordUser.Id == "" {
return nil, errors.New("返回值无效,用户字段为空,请稍后再试!")
}
defer resp.Body.Close()
return &discordUser, nil
}
func DiscordOAuth(c *gin.Context) {
session := sessions.Default(c)
username := session.Get("username")
if username != nil {
DiscordBind(c)
return
}
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
err := user.FillUserByDiscordId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "discord_" + strconv.Itoa(model.GetMaxUserId()+1)
if discordUser.Username != "" {
user.DisplayName = discordUser.Username
} else {
user.DisplayName = "Discord User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func DiscordBind(c *gin.Context) {
if !common.DiscordOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Discord 登录以及注册",
})
return
}
code := c.Query("code")
discordUser, err := getDiscordUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
DiscordId: discordUser.Id,
}
if model.IsDiscordIdAlreadyTaken(user.DiscordId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Discord 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.DiscordId = discordUser.Id
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -7,12 +7,9 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
) )
@@ -33,7 +30,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
if code == "" { if code == "" {
return nil, errors.New("无效的参数") return nil, errors.New("无效的参数")
} }
values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code}
jsonData, err := json.Marshal(values) jsonData, err := json.Marshal(values)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -49,7 +46,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
} }
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
logger.SysLog(err.Error()) common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
} }
defer res.Body.Close() defer res.Body.Close()
@@ -65,7 +62,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req) res2, err := client.Do(req)
if err != nil { if err != nil {
logger.SysLog(err.Error()) common.SysLog(err.Error())
return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!")
} }
defer res2.Body.Close() defer res2.Body.Close()
@@ -82,21 +79,13 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) {
func GitHubOAuth(c *gin.Context) { func GitHubOAuth(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username") username := session.Get("username")
if username != nil { if username != nil {
GitHubBind(c) GitHubBind(c)
return return
} }
if !config.GitHubOAuthEnabled { if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@@ -125,7 +114,7 @@ func GitHubOAuth(c *gin.Context) {
return return
} }
} else { } else {
if config.RegisterEnabled { if common.RegisterEnabled {
user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1)
if githubUser.Name != "" { if githubUser.Name != "" {
user.DisplayName = githubUser.Name user.DisplayName = githubUser.Name
@@ -163,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
} }
func GitHubBind(c *gin.Context) { func GitHubBind(c *gin.Context) {
if !config.GitHubOAuthEnabled { if !common.GitHubOAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "管理员未开启通过 GitHub 登录以及注册", "message": "管理员未开启通过 GitHub 登录以及注册",
@@ -216,22 +205,3 @@ func GitHubBind(c *gin.Context) {
}) })
return return
} }
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
state := helper.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": state,
})
}

226
controller/google.go Normal file
View File

@@ -0,0 +1,226 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
type GoogleAccessTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
RefreshToken string `json:"refresh_token"`
}
type GoogleUser struct {
Sub string `json:"sub"`
Name string `json:"name"`
}
func getGoogleUserInfoByCode(codeFromURLParamaters string, host string) (*GoogleUser, error) {
if codeFromURLParamaters == "" {
return nil, errors.New("无效参数")
}
RequestClient := &http.Client{}
accessTokenBody := bytes.NewBuffer([]byte(fmt.Sprintf(
"code=%s&client_id=%s&client_secret=%s&redirect_uri=%s/oauth/google&grant_type=authorization_code",
codeFromURLParamaters, common.GoogleClientId, common.GoogleClientSecret, common.ServerAddress,
)))
req, _ := http.NewRequest("POST",
"https://oauth2.googleapis.com/token",
accessTokenBody,
)
req.Header = http.Header{
"Content-Type": []string{"application/x-www-form-urlencoded"},
"Accept": []string{"application/json"},
}
resp, err := RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("访问令牌无效")
}
var googleTokenResponse GoogleAccessTokenResponse
json.NewDecoder(resp.Body).Decode(&googleTokenResponse)
accessToken := "Bearer " + googleTokenResponse.AccessToken
// Get User Info
req, _ = http.NewRequest("GET", "https://www.googleapis.com/oauth2/v3/userinfo", nil)
req.Header = http.Header{
"Content-Type": []string{"application/json"},
"Authorization": []string{accessToken},
}
defer resp.Body.Close()
resp, err = RequestClient.Do(req)
if resp.StatusCode != 200 || err != nil {
return nil, errors.New("Google 用户信息无效")
}
var googleUser GoogleUser
// Parse json to googleUser
err = json.NewDecoder(resp.Body).Decode(&googleUser)
if err != nil {
return nil, err
}
if googleUser.Sub == "" {
return nil, errors.New("返回值无效,用户字段为空,请稍后再试!")
}
defer resp.Body.Close()
return &googleUser, nil
}
func GoogleOAuth(c *gin.Context) {
session := sessions.Default(c)
username := session.Get("username")
if username != nil {
GoogleBind(c)
return
}
if !common.GoogleOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Google 登录以及注册",
})
return
}
code := c.Query("code")
googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
GoogleId: googleUser.Sub,
}
if model.IsGoogleIdAlreadyTaken(user.GoogleId) {
err := user.FillUserByGoogleId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if common.RegisterEnabled {
user.Username = "google_" + strconv.Itoa(model.GetMaxUserId()+1)
if googleUser.Name != "" {
user.DisplayName = googleUser.Name
} else {
user.DisplayName = "Google User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
setupLogin(&user, c)
}
func GoogleBind(c *gin.Context) {
if !common.GoogleOAuthEnabled {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员未开启通过 Google 登录以及注册",
})
return
}
code := c.Query("code")
googleUser, err := getGoogleUserInfoByCode(code, c.Request.Host)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
GoogleId: googleUser.Sub,
}
if model.IsGoogleIdAlreadyTaken(user.GoogleId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该 Google 账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.GoogleId = googleUser.Sub
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -2,13 +2,13 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"net/http" "net/http"
"one-api/common"
) )
func GetGroups(c *gin.Context) { func GetGroups(c *gin.Context) {
groupNames := make([]string, 0) groupNames := make([]string, 0)
for groupName := range common.GroupRatio { for groupName, _ := range common.GroupRatio {
groupNames = append(groupNames, groupName) groupNames = append(groupNames, groupName)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -2,9 +2,8 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config" "one-api/common"
"github.com/songquanpeng/one-api/model" "one-api/model"
"net/http"
"strconv" "strconv"
) )
@@ -19,21 +18,19 @@ func GetAllLogs(c *gin.Context) {
username := c.Query("username") username := c.Query("username")
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func GetUserLogs(c *gin.Context) { func GetUserLogs(c *gin.Context) {
@@ -47,38 +44,36 @@ func GetUserLogs(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func SearchAllLogs(c *gin.Context) { func SearchAllLogs(c *gin.Context) {
keyword := c.Query("keyword") keyword := c.Query("keyword")
logs, err := model.SearchAllLogs(keyword) logs, err := model.SearchAllLogs(keyword)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func SearchUserLogs(c *gin.Context) { func SearchUserLogs(c *gin.Context) {
@@ -86,18 +81,17 @@ func SearchUserLogs(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
logs, err := model.SearchUserLogs(userId, keyword) logs, err := model.SearchUserLogs(userId, keyword)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
}) })
return return
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": logs, "data": logs,
}) })
return
} }
func GetLogsStat(c *gin.Context) { func GetLogsStat(c *gin.Context) {
@@ -107,10 +101,9 @@ func GetLogsStat(c *gin.Context) {
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
username := c.Query("username") username := c.Query("username")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "") //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, "")
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
@@ -118,7 +111,6 @@ func GetLogsStat(c *gin.Context) {
//"token": tokenNum, //"token": tokenNum,
}, },
}) })
return
} }
func GetLogsSelfStat(c *gin.Context) { func GetLogsSelfStat(c *gin.Context) {
@@ -128,10 +120,9 @@ func GetLogsSelfStat(c *gin.Context) {
endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64)
tokenName := c.Query("token_name") tokenName := c.Query("token_name")
modelName := c.Query("model_name") modelName := c.Query("model_name")
channel, _ := strconv.Atoi(c.Query("channel")) quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
quotaNum := model.SumUsedQuota(logType, startTimestamp, endTimestamp, modelName, username, tokenName, channel)
//tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName) //tokenNum := model.SumUsedToken(logType, startTimestamp, endTimestamp, modelName, username, tokenName)
c.JSON(http.StatusOK, gin.H{ c.JSON(200, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": gin.H{ "data": gin.H{
@@ -139,30 +130,4 @@ func GetLogsSelfStat(c *gin.Context) {
//"token": tokenNum, //"token": tokenNum,
}, },
}) })
return
}
func DeleteHistoryLogs(c *gin.Context) {
targetTimestamp, _ := strconv.ParseInt(c.Query("target_timestamp"), 10, 64)
if targetTimestamp == 0 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "target timestamp is required",
})
return
}
count, err := model.DeleteOldLog(targetTimestamp)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": count,
})
return
} }

View File

@@ -3,11 +3,9 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strings" "one-api/common"
"one-api/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -19,55 +17,59 @@ func GetStatus(c *gin.Context) {
"data": gin.H{ "data": gin.H{
"version": common.Version, "version": common.Version,
"start_time": common.StartTime, "start_time": common.StartTime,
"email_verification": config.EmailVerificationEnabled, "email_verification": common.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled, "github_oauth": common.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId, "github_client_id": common.GitHubClientId,
"system_name": config.SystemName, "discord_oauth": common.DiscordOAuthEnabled,
"logo": config.Logo, "discord_client_id": common.DiscordClientId,
"footer_html": config.Footer, "google_oauth": common.GoogleOAuthEnabled,
"wechat_qrcode": config.WeChatAccountQRCodeImageURL, "google_client_id": common.GoogleClientId,
"wechat_login": config.WeChatAuthEnabled, "system_name": common.SystemName,
"server_address": config.ServerAddress, "logo": common.Logo,
"turnstile_check": config.TurnstileCheckEnabled, "footer_html": common.Footer,
"turnstile_site_key": config.TurnstileSiteKey, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"top_up_link": config.TopUpLink, "wechat_login": common.WeChatAuthEnabled,
"chat_link": config.ChatLink, "server_address": common.ServerAddress,
"quota_per_unit": config.QuotaPerUnit, "turnstile_check": common.TurnstileCheckEnabled,
"display_in_currency": config.DisplayInCurrencyEnabled, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink,
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
}, },
}) })
return return
} }
func GetNotice(c *gin.Context) { func GetNotice(c *gin.Context) {
config.OptionMapRWMutex.RLock() common.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock() defer common.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": config.OptionMap["Notice"], "data": common.OptionMap["Notice"],
}) })
return return
} }
func GetAbout(c *gin.Context) { func GetAbout(c *gin.Context) {
config.OptionMapRWMutex.RLock() common.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock() defer common.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": config.OptionMap["About"], "data": common.OptionMap["About"],
}) })
return return
} }
func GetHomePageContent(c *gin.Context) { func GetHomePageContent(c *gin.Context) {
config.OptionMapRWMutex.RLock() common.OptionMapRWMutex.RLock()
defer config.OptionMapRWMutex.RUnlock() defer common.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
"data": config.OptionMap["HomePageContent"], "data": common.OptionMap["HomePageContent"],
}) })
return return
} }
@@ -81,22 +83,6 @@ func SendEmailVerification(c *gin.Context) {
}) })
return return
} }
if config.EmailDomainRestrictionEnabled {
allowed := false
for _, domain := range config.EmailDomainWhitelist {
if strings.HasSuffix(email, "@"+domain) {
allowed = true
break
}
}
if !allowed {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员启用了邮箱域名白名单,您的邮箱地址的域名不在白名单中",
})
return
}
}
if model.IsEmailAlreadyTaken(email) { if model.IsEmailAlreadyTaken(email) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -106,10 +92,10 @@ func SendEmailVerification(c *gin.Context) {
} }
code := common.GenerateVerificationCode(6) code := common.GenerateVerificationCode(6)
common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose)
subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s邮箱验证。</p>"+
"<p>您的验证码为: <strong>%s</strong></p>"+ "<p>您的验证码为: <strong>%s</strong></p>"+
"<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, code, common.VerificationValidMinutes) "<p>验证码 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, code, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -143,12 +129,12 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", config.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+
"<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+ "<p>如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s </p>"+
"<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", config.SystemName, link, link, common.VerificationValidMinutes) "<p>重置链接 %d 分钟内有效,如果不是本人操作,请忽略。</p>", common.SystemName, link, link, common.VerificationValidMinutes)
err := common.SendEmail(subject, email, content) err := common.SendEmail(subject, email, content)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -2,12 +2,8 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
) )
// https://platform.openai.com/docs/api-reference/models/list // https://platform.openai.com/docs/api-reference/models/list
@@ -57,43 +53,277 @@ func init() {
IsBlocking: false, IsBlocking: false,
}) })
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ { openAIModels = []OpenAIModels{
adaptor := helper.GetAdaptor(i) {
channelName := adaptor.GetChannelName() Id: "dall-e",
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: channelName,
Permission: permission,
Root: modelName,
Parent: nil,
})
}
}
for _, modelName := range ai360.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1677649963,
OwnedBy: "360", OwnedBy: "openai",
Permission: permission, Permission: permission,
Root: modelName, Root: "dall-e",
Parent: nil, Parent: nil,
}) },
} {
for _, modelName := range moonshot.ModelList { Id: "gpt-3.5-turbo",
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1677649963,
OwnedBy: "moonshot", OwnedBy: "openai",
Permission: permission, Permission: permission,
Root: modelName, Root: "gpt-3.5-turbo",
Parent: nil, Parent: nil,
}) },
{
Id: "gpt-3.5-turbo-0301",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0301",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-0613",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k",
Parent: nil,
},
{
Id: "gpt-3.5-turbo-16k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-3.5-turbo-16k-0613",
Parent: nil,
},
{
Id: "gpt-4",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4",
Parent: nil,
},
{
Id: "gpt-4-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0314",
Parent: nil,
},
{
Id: "gpt-4-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-0613",
Parent: nil,
},
{
Id: "gpt-4-32k",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k",
Parent: nil,
},
{
Id: "gpt-4-32k-0314",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0314",
Parent: nil,
},
{
Id: "gpt-4-32k-0613",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "gpt-4-32k-0613",
Parent: nil,
},
{
Id: "text-embedding-ada-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-embedding-ada-002",
Parent: nil,
},
{
Id: "text-davinci-003",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-003",
Parent: nil,
},
{
Id: "text-davinci-002",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-002",
Parent: nil,
},
{
Id: "text-curie-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-curie-001",
Parent: nil,
},
{
Id: "text-babbage-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-babbage-001",
Parent: nil,
},
{
Id: "text-ada-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-ada-001",
Parent: nil,
},
{
Id: "text-moderation-latest",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-latest",
Parent: nil,
},
{
Id: "text-moderation-stable",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-moderation-stable",
Parent: nil,
},
{
Id: "text-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "text-davinci-edit-001",
Parent: nil,
},
{
Id: "code-davinci-edit-001",
Object: "model",
Created: 1677649963,
OwnedBy: "openai",
Permission: permission,
Root: "code-davinci-edit-001",
Parent: nil,
},
{
Id: "claude-instant-1",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
Permission: permission,
Root: "claude-instant-1",
Parent: nil,
},
{
Id: "claude-2",
Object: "model",
Created: 1677649963,
OwnedBy: "anturopic",
Permission: permission,
Root: "claude-2",
Parent: nil,
},
{
Id: "ERNIE-Bot",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot",
Parent: nil,
},
{
Id: "ERNIE-Bot-turbo",
Object: "model",
Created: 1677649963,
OwnedBy: "baidu",
Permission: permission,
Root: "ERNIE-Bot-turbo",
Parent: nil,
},
{
Id: "PaLM-2",
Object: "model",
Created: 1677649963,
OwnedBy: "google",
Permission: permission,
Root: "PaLM-2",
Parent: nil,
},
{
Id: "chatglm_pro",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_pro",
Parent: nil,
},
{
Id: "chatglm_std",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_std",
Parent: nil,
},
{
Id: "chatglm_lite",
Object: "model",
Created: 1677649963,
OwnedBy: "zhipu",
Permission: permission,
Root: "chatglm_lite",
Parent: nil,
},
} }
openAIModelsMap = make(map[string]OpenAIModels) openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range openAIModels {
@@ -113,14 +343,14 @@ func RetrieveModel(c *gin.Context) {
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model) c.JSON(200, model)
} else { } else {
Error := relaymodel.Error{ openAIError := OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId), Message: fmt.Sprintf("The model '%s' does not exist", modelId),
Type: "invalid_request_error", Type: "invalid_request_error",
Param: "model", Param: "model",
Code: "model_not_found", Code: "model_not_found",
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"error": Error, "error": openAIError,
}) })
} }
} }

View File

@@ -2,10 +2,9 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -13,17 +12,17 @@ import (
func GetOptions(c *gin.Context) { func GetOptions(c *gin.Context) {
var options []*model.Option var options []*model.Option
config.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
for k, v := range config.OptionMap { for k, v := range common.OptionMap {
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
continue continue
} }
options = append(options, &model.Option{ options = append(options, &model.Option{
Key: k, Key: k,
Value: helper.Interface2String(v), Value: common.Interface2String(v),
}) })
} }
config.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,
"message": "", "message": "",
@@ -43,40 +42,40 @@ func UpdateOption(c *gin.Context) {
return return
} }
switch option.Key { switch option.Key {
case "Theme":
if !config.ValidThemes[option.Value] {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无效的主题",
})
return
}
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
if option.Value == "true" && config.GitHubClientId == "" { if option.Value == "true" && common.GitHubClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 GitHub OAuth请先填入 GitHub Client Id 以及 GitHub Client Secret", "message": "无法启用 GitHub OAuth请先填入 GitHub Client ID 以及 GitHub Client Secret",
}) })
return return
} }
case "EmailDomainRestrictionEnabled": case "DiscordOAuthEnabled":
if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { if option.Value == "true" && common.DiscordClientId == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用邮箱域名限制,请先填入限制的邮箱域名", "message": "无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret",
}) })
return return
} }
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
if option.Value == "true" && config.WeChatServerAddress == "" { if option.Value == "true" && common.WeChatServerAddress == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用微信登录,请先填入微信登录相关配置信息!", "message": "无法启用微信登录,请先填入微信登录相关配置信息!",
}) })
return return
} }
case "GoogleOAuthEnabled":
if option.Value == "true" && common.GoogleClientId == "" {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法启用 Google OAuth请先填入 Google Client ID 以及 Google Client Secret",
})
return
}
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
if option.Value == "true" && config.TurnstileSiteKey == "" { if option.Value == "true" && common.TurnstileSiteKey == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!",

View File

@@ -2,10 +2,9 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
) )
@@ -14,7 +13,7 @@ func GetAllRedemptions(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -106,12 +105,12 @@ func AddRedemption(c *gin.Context) {
} }
var keys []string var keys []string
for i := 0; i < redemption.Count; i++ { for i := 0; i < redemption.Count; i++ {
key := helper.GetUUID() key := common.GetUUID()
cleanRedemption := model.Redemption{ cleanRedemption := model.Redemption{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: redemption.Name, Name: redemption.Name,
Key: key, Key: key,
CreatedTime: helper.GetTimestamp(), CreatedTime: common.GetTimestamp(),
Quota: redemption.Quota, Quota: redemption.Quota,
} }
err = cleanRedemption.Insert() err = cleanRedemption.Insert()

214
controller/relay-baidu.go Normal file
View File

@@ -0,0 +1,214 @@
package controller
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strings"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type BaiduTokenResponse struct {
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
SessionKey string `json:"session_key"`
AccessToken string `json:"access_token"`
Scope string `json:"scope"`
SessionSecret string `json:"session_secret"`
}
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type BaiduError struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage Usage `json:"usage"`
BaiduError
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
messages := make([]BaiduMessage, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, BaiduMessage{
Role: "user",
Content: message.Content,
})
messages = append(messages, BaiduMessage{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, BaiduMessage{
Role: message.Role,
Content: message.Content,
})
}
}
return &BaiduChatRequest{
Messages: messages,
Stream: request.Stream,
}
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
choice := OpenAITextResponseChoice{
Index: 0,
Message: Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
fullTextResponse := OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []OpenAITextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
choice.FinishReason = "stop"
response := ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse BaiduChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
usage.PromptTokens += baiduResponse.Usage.PromptTokens
usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
usage.TotalTokens += baiduResponse.Usage.TotalTokens
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func baiduHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,20 +1,44 @@
package anthropic package controller
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"
"one-api/common"
"strings" "strings"
) )
type ClaudeMetadata struct {
UserId string `json:"user_id"`
}
type ClaudeRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ClaudeError struct {
Type string `json:"type"`
Message string `json:"message"`
}
type ClaudeResponse struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error ClaudeError `json:"error"`
}
func stopReasonClaude2OpenAI(reason string) string { func stopReasonClaude2OpenAI(reason string) string {
switch reason { switch reason {
case "stop_sequence": case "stop_sequence":
@@ -26,8 +50,8 @@ func stopReasonClaude2OpenAI(reason string) string {
} }
} }
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func requestOpenAI2Claude(textRequest GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := Request{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", Prompt: "",
MaxTokensToSample: textRequest.MaxTokens, MaxTokensToSample: textRequest.MaxTokens,
@@ -46,9 +70,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
} else if message.Role == "assistant" { } else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content) prompt += fmt.Sprintf("\n\nAssistant: %s", message.Content)
} else if message.Role == "system" { } else if message.Role == "system" {
if prompt == "" { prompt += fmt.Sprintf("\n\nSystem: %s", message.Content)
prompt = message.StringContent()
}
} }
} }
prompt += "\n\nAssistant:" prompt += "\n\nAssistant:"
@@ -56,43 +78,40 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
return &claudeRequest return &claudeRequest
} }
func streamResponseClaude2OpenAI(claudeResponse *Response) *openai.ChatCompletionsStreamResponse { func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = claudeResponse.Completion choice.Delta.Content = claudeResponse.Completion
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason) choice.FinishReason = stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" { var response ChatCompletionsStreamResponse
choice.FinishReason = &finishReason
}
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response return &response
} }
func responseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
choice := openai.TextResponseChoice{ choice := OpenAITextResponseChoice{
Index: 0, Index: 0,
Message: model.Message{ Message: Message{
Role: "assistant", Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "), Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil, Name: nil,
}, },
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
} }
fullTextResponse := openai.TextResponse{ fullTextResponse := OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: common.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice}, Choices: []OpenAITextResponseChoice{choice},
} }
return &fullTextResponse return &fullTextResponse
} }
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := "" responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := helper.GetTimestamp() createdTime := common.GetTimestamp()
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
@@ -119,16 +138,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
} }
stopChan <- true stopChan <- true
}() }()
common.SetEventStreamHeaders(c) c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
// some implementations may add \r at the end of data // some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var claudeResponse Response var claudeResponse ClaudeResponse
err := json.Unmarshal([]byte(data), &claudeResponse) err := json.Unmarshal([]byte(data), &claudeResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
responseText += claudeResponse.Completion responseText += claudeResponse.Completion
@@ -137,7 +160,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
response.Created = createdTime response.Created = createdTime
jsonStr, err := json.Marshal(response) jsonStr, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
@@ -149,28 +172,28 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
} }
return nil, responseText return nil, responseText
} }
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { func claudeHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
var claudeResponse Response var claudeResponse ClaudeResponse
err = json.Unmarshal(responseBody, &claudeResponse) err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if claudeResponse.Error.Type != "" { if claudeResponse.Error.Type != "" {
return &model.ErrorWithStatusCode{ return &OpenAIErrorWithStatusCode{
Error: model.Error{ OpenAIError: OpenAIError{
Message: claudeResponse.Error.Message, Message: claudeResponse.Error.Message,
Type: claudeResponse.Error.Type, Type: claudeResponse.Error.Type,
Param: "", Param: "",
@@ -180,9 +203,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}, nil }, nil
} }
fullTextResponse := responseClaude2OpenAI(&claudeResponse) fullTextResponse := responseClaude2OpenAI(&claudeResponse)
fullTextResponse.Model = modelName completionTokens := countTokenText(claudeResponse.Completion, model)
completionTokens := openai.CountTokenText(claudeResponse.Completion, modelName) usage := Usage{
usage := model.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
CompletionTokens: completionTokens, CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens, TotalTokens: promptTokens + completionTokens,
@@ -190,7 +212,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
fullTextResponse.Usage = usage fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)

180
controller/relay-image.go Normal file
View File

@@ -0,0 +1,180 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
imageModel := "dall-e"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var imageRequest ImageRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &imageRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
// Prompt validation
if imageRequest.Prompt == "" {
return errorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest)
}
// Not "256x256", "512x512", or "1024x1024"
if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" {
return errorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024"), "invalid_field_value", http.StatusBadRequest)
}
// N should between 1 and 10
if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
return errorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
// Size
if imageRequest.Size == "256x256" {
sizeRatio = 1
} else if imageRequest.Size == "512x512" {
sizeRatio = 1.125
} else if imageRequest.Size == "1024x1024" {
sizeRatio = 1.25
}
quota := int(ratio*sizeRatio*1000) * imageRequest.N
if consumeQuota && userQuota-quota < 0 {
return errorWrapper(err, "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var textResponse ImageResponse
defer func() {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

153
controller/relay-openai.go Normal file
View File

@@ -0,0 +1,153 @@
package controller
import (
"bufio"
"bytes"
"encoding/json"
"io"
"net/http"
"one-api/common"
"strings"
"github.com/gin-gonic/gin"
)
func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
// ChatGPT Next Web
if strings.HasPrefix(data, "event:") || strings.Contains(data, "event:") {
// Remove event: event in the front or back
data = strings.TrimPrefix(data, "event: event")
data = strings.TrimSuffix(data, "event: event")
// Remove everything, only keep `data: {...}` <--- this is the json
// Find the start and end indices of `data: {...}` substring
startIndex := strings.Index(data, "data:")
endIndex := strings.LastIndex(data, "}")
// If both indices are found and end index is greater than start index
if startIndex != -1 && endIndex != -1 && endIndex > startIndex {
// Extract the `data: {...}` substring
data = data[startIndex : endIndex+1]
}
}
if !strings.HasPrefix(data, "data:") {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
case RelayModeChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
}
case RelayModeCompletions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
for _, choice := range streamResponse.Choices {
responseText += choice.Text
}
}
}
}
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
c.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool) (*OpenAIErrorWithStatusCode, *Usage) {
var textResponse TextResponse
if consumeQuota {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &textResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if textResponse.Error.Type != "" {
return &OpenAIErrorWithStatusCode{
OpenAIError: textResponse.Error,
StatusCode: resp.StatusCode,
}, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &textResponse.Usage
}

209
controller/relay-palm.go Normal file
View File

@@ -0,0 +1,209 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}
func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.Content,
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
fullTextResponse := OpenAITextResponse{
Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := OpenAITextResponseChoice{
Index: i,
Message: Message{
Role: "assistant",
Content: candidate.Content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *ChatCompletionsStreamResponse {
var choice ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.Content = palmResponse.Candidates[0].Content
}
choice.FinishReason = "stop"
var response ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []ChatCompletionsStreamResponseChoice{choice}
return &response
}
func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
responseText := ""
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
err = resp.Body.Close()
if err != nil {
common.SysError("error closing stream response: " + err.Error())
stopChan <- true
return
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return &OpenAIErrorWithStatusCode{
OpenAIError: OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens := countTokenText(palmResponse.Candidates[0].Content, model)
usage := Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

424
controller/relay-text.go Normal file
View File

@@ -0,0 +1,424 @@
package controller
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
const (
APITypeOpenAI = iota
APITypeClaude
APITypePaLM
APITypeBaidu
APITypeZhipu
)
var httpClient *http.Client
func init() {
httpClient = &http.Client{}
}
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
channelType := c.GetInt("channel")
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var textRequest GeneralOpenAIRequest
if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
err := common.UnmarshalBodyReusable(c, &textRequest)
if err != nil {
return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
}
if relayMode == RelayModeModerations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == RelayModeEmbeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
// request validation
if textRequest.Model == "" {
return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest)
}
switch relayMode {
case RelayModeCompletions:
if textRequest.Prompt == "" {
return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEmbeddings:
case RelayModeModerations:
if textRequest.Input == "" {
return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
}
case RelayModeEdits:
if textRequest.Instruction == "" {
return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest)
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
}
if modelMap[textRequest.Model] != "" {
textRequest.Model = modelMap[textRequest.Model]
isModelMapped = true
}
}
apiType := APITypeOpenAI
switch channelType {
case common.ChannelTypeAnthropic:
apiType = APITypeClaude
case common.ChannelTypeBaidu:
apiType = APITypeBaidu
case common.ChannelTypePaLM:
apiType = APITypePaLM
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString("api_version")
}
requestURL := strings.Split(requestURL, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
baseURL = c.GetString("base_url")
task := strings.TrimPrefix(requestURL, "/v1/")
model_ := textRequest.Model
model_ = strings.Replace(model_, ".", "", -1)
// https://github.com/songquanpeng/one-api/issues/67
model_ = strings.TrimSuffix(model_, "-0301")
model_ = strings.TrimSuffix(model_, "-0314")
model_ = strings.TrimSuffix(model_, "-0613")
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/%s", baseURL, model_, task)
}
case APITypeClaude:
fullRequestURL = "https://api.anthropic.com/v1/complete"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL)
}
case APITypeBaidu:
switch textRequest.Model {
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
case APITypePaLM:
fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
if baseURL != "" {
fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
fullRequestURL += "?key=" + apiKey
case APITypeZhipu:
method := "invoke"
if textRequest.Stream {
method = "sse-invoke"
}
fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
}
var promptTokens int
var completionTokens int
switch relayMode {
case RelayModeChatCompletions:
promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
case RelayModeCompletions:
promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
case RelayModeModerations:
promptTokens = countTokenInput(textRequest.Input, textRequest.Model)
}
preConsumedTokens := common.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
modelRatio := common.GetModelRatio(textRequest.Model)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 10*preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
}
if consumeQuota && preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
if err != nil {
return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
}
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
switch apiType {
case APITypeClaude:
claudeRequest := requestOpenAI2Claude(textRequest)
jsonStr, err := json.Marshal(claudeRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeBaidu:
baiduRequest := requestOpenAI2Baidu(textRequest)
jsonStr, err := json.Marshal(baiduRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypePaLM:
palmRequest := requestOpenAI2PaLM(textRequest)
jsonStr, err := json.Marshal(palmRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
case APITypeZhipu:
zhipuRequest := requestOpenAI2Zhipu(textRequest)
jsonStr, err := json.Marshal(zhipuRequest)
if err != nil {
return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
switch apiType {
case APITypeOpenAI:
if channelType == common.ChannelTypeAzure {
req.Header.Set("api-key", apiKey)
} else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
}
case APITypeClaude:
req.Header.Set("x-api-key", apiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
case APITypeZhipu:
token := getZhipuToken(apiKey)
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
//req.Header.Set("Connection", c.Request.Header.Get("Connection"))
resp, err := httpClient.Do(req)
if err != nil {
return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if resp.StatusCode != http.StatusOK {
return errorWrapper(nil, "bad_status_code", resp.StatusCode)
}
var textResponse TextResponse
isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
var streamResponseText string
defer func() {
if consumeQuota {
quota := 0
completionRatio := 1.0
if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
completionRatio = 1.333333
}
if strings.HasPrefix(textRequest.Model, "gpt-4") {
completionRatio = 2
}
if isStream && apiType != APITypeBaidu && apiType != APITypeZhipu {
completionTokens = countTokenText(streamResponseText, textRequest.Model)
} else {
promptTokens = textResponse.Usage.PromptTokens
completionTokens = textResponse.Usage.CompletionTokens
if apiType == APITypeZhipu {
// zhipu's API does not return prompt tokens & completion tokens
promptTokens = textResponse.Usage.TotalTokens
}
}
quota = promptTokens + int(float64(completionTokens)*completionRatio)
quota = int(float64(quota) * ratio)
if ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
switch apiType {
case APITypeOpenAI:
if isStream {
err, responseText := openaiStreamHandler(c, resp, relayMode)
if err != nil {
return err
}
streamResponseText = responseText
return nil
} else {
err, usage := openaiHandler(c, resp, consumeQuota)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeClaude:
if isStream {
err, responseText := claudeStreamHandler(c, resp)
if err != nil {
return err
}
streamResponseText = responseText
return nil
} else {
err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeBaidu:
if isStream {
err, usage := baiduStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := baiduHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypePaLM:
if textRequest.Stream { // PaLM2 API does not support stream
err, responseText := palmStreamHandler(c, resp)
if err != nil {
return err
}
streamResponseText = responseText
return nil
} else {
err, usage := palmHandler(c, resp, promptTokens, textRequest.Model)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
case APITypeZhipu:
if isStream {
err, usage := zhipuStreamHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
} else {
err, usage := zhipuHandler(c, resp)
if err != nil {
return err
}
if usage != nil {
textResponse.Usage = *usage
}
return nil
}
default:
return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
}
}

106
controller/relay-utils.go Normal file
View File

@@ -0,0 +1,106 @@
package controller
import (
"fmt"
"github.com/pkoukk/tiktoken-go"
"one-api/common"
)
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
func getTokenEncoder(model string) *tiktoken.Tiktoken {
if tokenEncoder, ok := tokenEncoderMap[model]; ok {
return tokenEncoder
}
tokenEncoder, err := tiktoken.EncodingForModel(model)
if err != nil {
common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error()))
tokenEncoder, err = tiktoken.EncodingForModel("gpt-3.5-turbo")
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get token encoder for model gpt-3.5-turbo: %s", err.Error()))
}
}
tokenEncoderMap[model] = tokenEncoder
return tokenEncoder
}
func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
if common.ApproximateTokenEnabled {
return int(float64(len(text)) * 0.38)
}
return len(tokenEncoder.Encode(text, nil, nil))
}
func countTokenMessages(messages []Message, model string) int {
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Content)
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum
}
func countTokenInput(input any, model string) int {
switch input.(type) {
case string:
return countTokenText(input.(string), model)
case []string:
text := ""
for _, s := range input.([]string) {
text += s
}
return countTokenText(text, model)
}
return 0
}
func countTokenText(text string, model string) int {
tokenEncoder := getTokenEncoder(model)
return getTokenNum(tokenEncoder, text)
}
func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatusCode {
openAIError := OpenAIError{
Message: err.Error(),
Type: "one_api_error",
Code: code,
}
return &OpenAIErrorWithStatusCode{
OpenAIError: openAIError,
StatusCode: statusCode,
}
}
func shouldDisableChannel(err *OpenAIError) bool {
if !common.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
return false
}

View File

@@ -1,18 +1,13 @@
package zhipu package controller
import ( import (
"bufio" "bufio"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt" "github.com/golang-jwt/jwt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io" "io"
"net/http" "net/http"
"one-api/common"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -23,13 +18,53 @@ import (
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/invoke
// https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke // https://open.bigmodel.cn/api/paas/v3/model-api/chatglm_std/sse-invoke
type ZhipuMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ZhipuRequest struct {
Prompt []ZhipuMessage `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
type ZhipuResponseData struct {
TaskId string `json:"task_id"`
RequestId string `json:"request_id"`
TaskStatus string `json:"task_status"`
Choices []ZhipuMessage `json:"choices"`
Usage `json:"usage"`
}
type ZhipuResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Success bool `json:"success"`
Data ZhipuResponseData `json:"data"`
}
type ZhipuStreamMetaResponse struct {
RequestId string `json:"request_id"`
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
Usage `json:"usage"`
}
type zhipuTokenData struct {
Token string
ExpiryTime time.Time
}
var zhipuTokens sync.Map var zhipuTokens sync.Map
var expSeconds int64 = 24 * 3600 var expSeconds int64 = 24 * 3600
func GetToken(apikey string) string { func getZhipuToken(apikey string) string {
data, ok := zhipuTokens.Load(apikey) data, ok := zhipuTokens.Load(apikey)
if ok { if ok {
tokenData := data.(tokenData) tokenData := data.(zhipuTokenData)
if time.Now().Before(tokenData.ExpiryTime) { if time.Now().Before(tokenData.ExpiryTime) {
return tokenData.Token return tokenData.Token
} }
@@ -37,7 +72,7 @@ func GetToken(apikey string) string {
split := strings.Split(apikey, ".") split := strings.Split(apikey, ".")
if len(split) != 2 { if len(split) != 2 {
logger.SysError("invalid zhipu key: " + apikey) common.SysError("invalid zhipu key: " + apikey)
return "" return ""
} }
@@ -65,7 +100,7 @@ func GetToken(apikey string) string {
return "" return ""
} }
zhipuTokens.Store(apikey, tokenData{ zhipuTokens.Store(apikey, zhipuTokenData{
Token: tokenString, Token: tokenString,
ExpiryTime: expiryTime, ExpiryTime: expiryTime,
}) })
@@ -73,26 +108,26 @@ func GetToken(apikey string) string {
return tokenString return tokenString
} }
func ConvertRequest(request model.GeneralOpenAIRequest) *Request { func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
messages := make([]Message, 0, len(request.Messages)) messages := make([]ZhipuMessage, 0, len(request.Messages))
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
messages = append(messages, Message{ messages = append(messages, ZhipuMessage{
Role: "system", Role: "system",
Content: message.StringContent(), Content: message.Content,
}) })
messages = append(messages, Message{ messages = append(messages, ZhipuMessage{
Role: "user", Role: "user",
Content: "Okay", Content: "Okay",
}) })
} else { } else {
messages = append(messages, Message{ messages = append(messages, ZhipuMessage{
Role: message.Role, Role: message.Role,
Content: message.StringContent(), Content: message.Content,
}) })
} }
} }
return &Request{ return &ZhipuRequest{
Prompt: messages, Prompt: messages,
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
@@ -100,18 +135,18 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *Request {
} }
} }
func responseZhipu2OpenAI(response *Response) *openai.TextResponse { func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
fullTextResponse := openai.TextResponse{ fullTextResponse := OpenAITextResponse{
Id: response.Data.TaskId, Id: response.Data.TaskId,
Object: "chat.completion", Object: "chat.completion",
Created: helper.GetTimestamp(), Created: common.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Data.Choices)), Choices: make([]OpenAITextResponseChoice, 0, len(response.Data.Choices)),
Usage: response.Data.Usage, Usage: response.Data.Usage,
} }
for i, choice := range response.Data.Choices { for i, choice := range response.Data.Choices {
openaiChoice := openai.TextResponseChoice{ openaiChoice := OpenAITextResponseChoice{
Index: i, Index: i,
Message: model.Message{ Message: Message{
Role: choice.Role, Role: choice.Role,
Content: strings.Trim(choice.Content, "\""), Content: strings.Trim(choice.Content, "\""),
}, },
@@ -125,41 +160,42 @@ func responseZhipu2OpenAI(response *Response) *openai.TextResponse {
return &fullTextResponse return &fullTextResponse
} }
func streamResponseZhipu2OpenAI(zhipuResponse string) *openai.ChatCompletionsStreamResponse { func streamResponseZhipu2OpenAI(zhipuResponse string) *ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = zhipuResponse choice.Delta.Content = zhipuResponse
response := openai.ChatCompletionsStreamResponse{ choice.FinishReason = ""
response := ChatCompletionsStreamResponse{
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: common.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []ChatCompletionsStreamResponseChoice{choice},
} }
return &response return &response
} }
func streamMetaResponseZhipu2OpenAI(zhipuResponse *StreamMetaResponse) (*openai.ChatCompletionsStreamResponse, *model.Usage) { func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*ChatCompletionsStreamResponse, *Usage) {
var choice openai.ChatCompletionsStreamResponseChoice var choice ChatCompletionsStreamResponseChoice
choice.Delta.Content = "" choice.Delta.Content = ""
choice.FinishReason = &constant.StopFinishReason choice.FinishReason = "stop"
response := openai.ChatCompletionsStreamResponse{ response := ChatCompletionsStreamResponse{
Id: zhipuResponse.RequestId, Id: zhipuResponse.RequestId,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
Created: helper.GetTimestamp(), Created: common.GetTimestamp(),
Model: "chatglm", Model: "chatglm",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, Choices: []ChatCompletionsStreamResponseChoice{choice},
} }
return &response, &zhipuResponse.Usage return &response, &zhipuResponse.Usage
} }
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var usage *model.Usage var usage *Usage
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { if atEOF && len(data) == 0 {
return 0, nil, nil return 0, nil, nil
} }
if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 2, data[0:i], nil return i + 1, data[0:i], nil
} }
if atEOF { if atEOF {
return len(data), data, nil return len(data), data, nil
@@ -172,46 +208,45 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
go func() { go func() {
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
lines := strings.Split(data, "\n") data = strings.Trim(data, "\"")
for i, line := range lines { if len(data) < 5 { // ignore blank line or wrong format
if len(line) < 5 { continue
continue }
} if data[:5] == "data:" {
if line[:5] == "data:" { dataChan <- data[5:]
dataChan <- line[5:] } else if data[:5] == "meta:" {
if i != len(lines)-1 { metaChan <- data[5:]
dataChan <- "\n"
}
} else if line[:5] == "meta:" {
metaChan <- line[5:]
}
} }
} }
stopChan <- true stopChan <- true
}() }()
common.SetEventStreamHeaders(c) c.Writer.Header().Set("Content-Type", "text/event-stream")
c.Writer.Header().Set("Cache-Control", "no-cache")
c.Writer.Header().Set("Connection", "keep-alive")
c.Writer.Header().Set("Transfer-Encoding", "chunked")
c.Writer.Header().Set("X-Accel-Buffering", "no")
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
response := streamResponseZhipu2OpenAI(data) response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true return true
case data := <-metaChan: case data := <-metaChan:
var zhipuResponse StreamMetaResponse var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse) err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil { if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
return true return true
} }
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse) response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response) jsonResponse, err := json.Marshal(response)
if err != nil { if err != nil {
logger.SysError("error marshalling stream response: " + err.Error()) common.SysError("error marshalling stream response: " + err.Error())
return true return true
} }
usage = zhipuUsage usage = zhipuUsage
@@ -224,28 +259,28 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}) })
err := resp.Body.Close() err := resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
return nil, usage return nil, usage
} }
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { func zhipuHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
var zhipuResponse Response var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
err = json.Unmarshal(responseBody, &zhipuResponse) err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
} }
if !zhipuResponse.Success { if !zhipuResponse.Success {
return &model.ErrorWithStatusCode{ return &OpenAIErrorWithStatusCode{
Error: model.Error{ OpenAIError: OpenAIError{
Message: zhipuResponse.Msg, Message: zhipuResponse.Msg,
Type: "zhipu_error", Type: "zhipu_error",
Param: "", Param: "",
@@ -255,10 +290,9 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
}, nil }, nil
} }
fullTextResponse := responseZhipu2OpenAI(&zhipuResponse) fullTextResponse := responseZhipu2OpenAI(&zhipuResponse)
fullTextResponse.Model = "chatglm"
jsonResponse, err := json.Marshal(fullTextResponse) jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
} }
c.Writer.Header().Set("Content-Type", "application/json") c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode) c.Writer.WriteHeader(resp.StatusCode)

View File

@@ -2,57 +2,177 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http" "net/http"
"one-api/common"
"strconv" "strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
Name *string `json:"name,omitempty"`
}
const (
RelayModeUnknown = iota
RelayModeChatCompletions
RelayModeCompletions
RelayModeEmbeddings
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
MaxTokens int `json:"max_tokens"`
Stream bool `json:"stream"`
}
type TextRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Prompt string `json:"prompt"`
MaxTokens int `json:"max_tokens"`
//Stream bool `json:"stream"`
}
type ImageRequest struct {
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type OpenAIError struct {
Message string `json:"message"`
Type string `json:"type"`
Param string `json:"param"`
Code any `json:"code"`
}
type OpenAIErrorWithStatusCode struct {
OpenAIError
StatusCode int `json:"status_code"`
}
type TextResponse struct {
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type OpenAITextResponseChoice struct {
Index int `json:"index"`
Message `json:"message"`
FinishReason string `json:"finish_reason"`
}
type OpenAITextResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
}
type ChatCompletionsStreamResponseChoice struct {
Delta struct {
Content string `json:"content"`
} `json:"delta"`
FinishReason string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []ChatCompletionsStreamResponseChoice `json:"choices"`
}
type CompletionsStreamResponse struct {
Choices []struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
}
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := RelayModeUnknown
var err *model.ErrorWithStatusCode if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
relayMode = RelayModeChatCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
relayMode = RelayModeCompletions
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
relayMode = RelayModeEmbeddings
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
relayMode = RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
relayMode = RelayModeImagesGenerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
relayMode = RelayModeEdits
}
var err *OpenAIErrorWithStatusCode
switch relayMode { switch relayMode {
case constant.RelayModeImagesGenerations: case RelayModeImagesGenerations:
err = controller.RelayImageHelper(c, relayMode) err = relayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech:
fallthrough
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default: default:
err = controller.RelayTextHelper(c) err = relayTextHelper(c, relayMode)
} }
if err != nil { if err != nil {
requestId := c.GetString(logger.RequestIdKey)
retryTimesStr := c.Query("retry") retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr) retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" { if retryTimesStr == "" {
retryTimes = config.RetryTimes retryTimes = common.RetryTimes
} }
if retryTimes > 0 { if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else { } else {
if err.StatusCode == http.StatusTooManyRequests { if err.StatusCode == http.StatusTooManyRequests {
err.Error.Message = "当前分组上游负载已饱和,请稍后再试" err.OpenAIError.Message = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
} }
err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId)
c.JSON(err.StatusCode, gin.H{ c.JSON(err.StatusCode, gin.H{
"error": err.Error, "error": err.OpenAIError,
}) })
} }
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) { if shouldDisableChannel(&err.OpenAIError) {
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
channelName := c.GetString("channel_name") channelName := c.GetString("channel_name")
disableChannel(channelId, channelName, err.Message) disableChannel(channelId, channelName, err.Message)
@@ -61,7 +181,7 @@ func Relay(c *gin.Context) {
} }
func RelayNotImplemented(c *gin.Context) { func RelayNotImplemented(c *gin.Context) {
err := model.Error{ err := OpenAIError{
Message: "API not implemented", Message: "API not implemented",
Type: "one_api_error", Type: "one_api_error",
Param: "", Param: "",
@@ -73,11 +193,11 @@ func RelayNotImplemented(c *gin.Context) {
} }
func RelayNotFound(c *gin.Context) { func RelayNotFound(c *gin.Context) {
err := model.Error{ err := OpenAIError{
Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Message: fmt.Sprintf("API not found: %s:%s", c.Request.Method, c.Request.URL.Path),
Type: "invalid_request_error", Type: "one_api_error",
Param: "", Param: "",
Code: "", Code: "api_not_found",
} }
c.JSON(http.StatusNotFound, gin.H{ c.JSON(http.StatusNotFound, gin.H{
"error": err, "error": err,

View File

@@ -2,11 +2,9 @@ package controller
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
) )
@@ -16,7 +14,7 @@ func GetAllTokens(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -111,19 +109,19 @@ func AddToken(c *gin.Context) {
}) })
return return
} }
if len(token.Name) > 30 { if len(token.Name) == 0 || len(token.Name) > 20 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌名称长", "message": "令牌名称长度必须在1-20之间",
}) })
return return
} }
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: helper.GenerateKey(), Key: common.GenerateKey(),
CreatedTime: helper.GetTimestamp(), CreatedTime: common.GetTimestamp(),
AccessedTime: helper.GetTimestamp(), AccessedTime: common.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota, RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota, UnlimitedQuota: token.UnlimitedQuota,
@@ -173,13 +171,6 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if len(token.Name) > 30 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
})
return
}
cleanToken, err := model.GetTokenByIds(token.Id, userId) cleanToken, err := model.GetTokenByIds(token.Id, userId)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -189,7 +180,7 @@ func UpdateToken(c *gin.Context) {
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == common.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= common.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",

View File

@@ -3,13 +3,10 @@ package controller
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -21,7 +18,7 @@ type LoginRequest struct {
} }
func Login(c *gin.Context) { func Login(c *gin.Context) {
if !config.PasswordLoginEnabled { if !common.PasswordLoginEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了密码登录", "message": "管理员关闭了密码登录",
"success": false, "success": false,
@@ -108,14 +105,14 @@ func Logout(c *gin.Context) {
} }
func Register(c *gin.Context) { func Register(c *gin.Context) {
if !config.RegisterEnabled { if !common.RegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了新用户注册", "message": "管理员关闭了新用户注册",
"success": false, "success": false,
}) })
return return
} }
if !config.PasswordRegisterEnabled { if !common.PasswordRegisterEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册",
"success": false, "success": false,
@@ -138,7 +135,7 @@ func Register(c *gin.Context) {
}) })
return return
} }
if config.EmailVerificationEnabled { if common.EmailVerificationEnabled {
if user.Email == "" || user.VerificationCode == "" { if user.Email == "" || user.VerificationCode == "" {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -162,7 +159,7 @@ func Register(c *gin.Context) {
DisplayName: user.Username, DisplayName: user.Username,
InviterId: inviterId, InviterId: inviterId,
} }
if config.EmailVerificationEnabled { if common.EmailVerificationEnabled {
cleanUser.Email = user.Email cleanUser.Email = user.Email
} }
if err := cleanUser.Insert(inviterId); err != nil { if err := cleanUser.Insert(inviterId); err != nil {
@@ -184,7 +181,7 @@ func GetAllUsers(c *gin.Context) {
if p < 0 { if p < 0 {
p = 0 p = 0
} }
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -251,29 +248,6 @@ func GetUser(c *gin.Context) {
return return
} }
func GetUserDashboard(c *gin.Context) {
id := c.GetInt("id")
now := time.Now()
startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix()
endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix()
dashboards, err := model.SearchLogsByDayAndModel(id, int(startOfDay), int(endOfDay))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法获取统计信息",
"data": nil,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": dashboards,
})
return
}
func GenerateAccessToken(c *gin.Context) { func GenerateAccessToken(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, err := model.GetUserById(id, true) user, err := model.GetUserById(id, true)
@@ -284,7 +258,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = helper.GetUUID() user.AccessToken = common.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -321,7 +295,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = helper.GetRandomString(4) user.AffCode = common.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -728,7 +702,7 @@ func EmailBind(c *gin.Context) {
return return
} }
if user.Role == common.RoleRootUser { if user.Role == common.RoleRootUser {
config.RootUserEmail = email common.RootUserEmail = email
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": true, "success": true,

View File

@@ -5,10 +5,9 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"time" "time"
) )
@@ -23,11 +22,11 @@ func getWeChatIdByCode(code string) (string, error) {
if code == "" { if code == "" {
return "", errors.New("无效的参数") return "", errors.New("无效的参数")
} }
req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil)
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header.Set("Authorization", config.WeChatServerToken) req.Header.Set("Authorization", common.WeChatServerToken)
client := http.Client{ client := http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@@ -51,7 +50,7 @@ func getWeChatIdByCode(code string) (string, error) {
} }
func WeChatAuth(c *gin.Context) { func WeChatAuth(c *gin.Context) {
if !config.WeChatAuthEnabled { if !common.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,
@@ -80,7 +79,7 @@ func WeChatAuth(c *gin.Context) {
return return
} }
} else { } else {
if config.RegisterEnabled { if common.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User" user.DisplayName = "WeChat User"
user.Role = common.RoleCommonUser user.Role = common.RoleCommonUser
@@ -113,7 +112,7 @@ func WeChatAuth(c *gin.Context) {
} }
func WeChatBind(c *gin.Context) { func WeChatBind(c *gin.Context) {
if !config.WeChatAuthEnabled { if !common.WeChatAuthEnabled {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "管理员未开启通过微信登录以及注册", "message": "管理员未开启通过微信登录以及注册",
"success": false, "success": false,

View File

@@ -2,28 +2,28 @@ version: '3.4'
services: services:
one-api: one-api:
image: justsong/one-api:latest image: ckt1031/one-api:latest
container_name: one-api container_name: one-api
restart: always restart: always
command: --log-dir /app/logs command: --log-dir /app/logs
ports: ports:
- "3000:3000" - "3000:3000"
volumes: volumes:
- ./data/oneapi:/data - ./data:/data
- ./logs:/app/logs - ./logs:/app/logs
environment: environment:
- SQL_DSN=oneapi:123456@tcp(db:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库 - SQL_DSN=root:123456@tcp(host.docker.internal:3306)/one-api # 修改此行,或注释掉以使用 SQLite 作为数据库
- REDIS_CONN_STRING=redis://redis - REDIS_CONN_STRING=redis://redis
- SESSION_SECRET=random_string # 修改为随机字符串 - SESSION_SECRET=random_string # 修改为随机字符串
- TZ=Asia/Shanghai - TZ=Asia/Shanghai
# - NODE_TYPE=slave # 多机部署时从节点取消注释该行 # - NODE_TYPE=slave # 多机部署时从节点取消注释该行
# - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行 # - SYNC_FREQUENCY=60 # 需要定期从数据库加载数据时取消注释该行
# - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行 # - FRONTEND_BASE_URL=https://openai.justsong.cn # 多机部署时从节点取消注释该行
depends_on: depends_on:
- redis - redis
- db
healthcheck: healthcheck:
test: [ "CMD-SHELL", "wget -q -O - http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk -F: '{print $2}'" ] test: [ "CMD-SHELL", "curl -s http://localhost:3000/api/status | grep -o '\"success\":\\s*true' | awk '{print $2}' | grep 'true'" ]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3
@@ -32,18 +32,3 @@ services:
image: redis:latest image: redis:latest
container_name: redis container_name: redis
restart: always restart: always
db:
image: mysql:8.2.0
restart: always
container_name: mysql
volumes:
- ./data/mysql:/var/lib/mysql # 挂载目录,持久化存储
ports:
- '3306:3306'
environment:
TZ: Asia/Shanghai # 设置时区
MYSQL_ROOT_PASSWORD: 'OneAPI@justsong' # 设置 root 用户的密码
MYSQL_USER: oneapi # 创建专用用户
MYSQL_PASSWORD: '123456' # 设置专用用户密码
MYSQL_DATABASE: one-api # 自动创建数据库

48
go.mod
View File

@@ -1,4 +1,4 @@
module github.com/songquanpeng/one-api module one-api
// +heroku goVersion go1.18 // +heroku goVersion go1.18
go 1.18 go 1.18
@@ -9,57 +9,55 @@ require (
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
github.com/gin-contrib/static v0.0.1 github.com/gin-contrib/static v0.0.1
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
github.com/go-playground/validator/v10 v10.14.0 github.com/go-playground/validator/v10 v10.14.1
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5 github.com/pkoukk/tiktoken-go v0.1.5
github.com/stretchr/testify v1.8.3 golang.org/x/crypto v0.11.0
golang.org/x/crypto v0.17.0 gorm.io/driver/mysql v1.5.1
golang.org/x/image v0.14.0 gorm.io/driver/sqlite v1.5.2
gorm.io/driver/mysql v1.4.3 gorm.io/gorm v1.25.2
gorm.io/driver/postgres v1.5.2
gorm.io/driver/sqlite v1.4.3
gorm.io/gorm v1.25.0
) )
require ( require (
github.com/bytedance/sonic v1.9.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.4.2 // indirect
)
require (
github.com/bytedance/sonic v1.9.2 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect github.com/go-sql-driver/mysql v1.7.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.3.1 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.9 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.4.0 // indirect
golang.org/x/net v0.17.0 // indirect golang.org/x/net v0.12.0 // indirect
golang.org/x/sys v0.15.0 // indirect golang.org/x/sys v0.10.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.11.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
gorm.io/driver/postgres v1.5.2
) )

78
go.sum
View File

@@ -1,8 +1,8 @@
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.2 h1:GDaNjuWSGu09guE9Oql0MSTNhNCLlWwO8y/xM5BzcbM=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= github.com/bytedance/sonic v1.9.2/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
@@ -43,12 +43,13 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI=
github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos= github.com/go-playground/validator/v10 v10.10.0/go.mod h1:74x4gJWsvQexRdW8Pn3dXSGrTK4nAUsbPlLADvpJkos=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= github.com/go-playground/validator/v10 v10.14.1 h1:9c50NUPC30zyuKprjL3vNZ0m5oG+jU0zvx4AqHGnv4k=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.1/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.9.7/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
@@ -67,25 +68,24 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
github.com/jackc/pgx/v5 v5.4.2 h1:u1gmGDwbdRUZiwisBm/Ky2M14uQyUP65bG8+20nnyrg=
github.com/jackc/pgx/v5 v5.4.2/go.mod h1:q6iHT8uDNXWiFNOlRqJzBTaSH3+2xCXkokxHZC5qWFY=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -102,7 +102,6 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -115,8 +114,8 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.9 h1:uH2qQXheeefCCkuBBSLi7jCiSmj3VRh2+Goq2N7Xxu0=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.9/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
@@ -136,8 +135,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
@@ -147,38 +146,36 @@ github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.4.0 h1:A8WCeEWhLwPBKNbFi5Wv5UTCBx5zzubnXDlMOFAzFMc=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.4.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio=
golang.org/x/image v0.14.0 h1:tNgSxAFe3jC4uYqvZdTr84SZoM1KfwdC9SKIFrLjFn4=
golang.org/x/image v0.14.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
@@ -193,14 +190,13 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k= gorm.io/driver/mysql v1.5.1 h1:WUEH5VF9obL/lTtzjmML/5e6VfFR/788coz2uaVCAZw=
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c= gorm.io/driver/mysql v1.5.1/go.mod h1:Jo3Xu7mMhCyj8dlrb3WoCaRd1FhsVh+yMXb1jUInf5o=
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0= gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8= gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU= gorm.io/driver/sqlite v1.5.2 h1:TpQ+/dqCY4uCigCFyrfnrJnrW9zjpelWVoEVNy5qJkc=
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI= gorm.io/driver/sqlite v1.5.2/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4=
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= gorm.io/gorm v1.25.2 h1:gs1o6Vsa+oVKG/a9ElL3XgyGfghFfkKA2SInQaCyMho=
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU= gorm.io/gorm v1.25.2/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -3,11 +3,6 @@
"%d 点额度": "%d point quota", "%d 点额度": "%d point quota",
"尚未实现": "Not yet implemented", "尚未实现": "Not yet implemented",
"余额不足": "Insufficient balance", "余额不足": "Insufficient balance",
"危险操作": "Hazardous operations",
"输入你的账户名": "Enter your account name",
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"通道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"", "\"通道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"通道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s", "通道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running", "测试已在运行中": "Test is already running",
@@ -39,8 +34,8 @@
"兑换码个数必须大于0": "The number of redemption codes must be greater than 0", "兑换码个数必须大于0": "The number of redemption codes must be greater than 0",
"一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100", "一次兑换码批量生成的个数不能大于 100": "The number of redemption codes generated in a batch cannot be greater than 100",
"通过令牌「%s」使用模型 %s 消耗 %s模型倍率 %.2f,分组倍率 %.2f": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)", "通过令牌「%s」使用模型 %s 消耗 %s模型倍率 %.2f,分组倍率 %.2f": "Using model %s with token %s consumes %s (model rate %.2f, group rate %.2f)",
"当前分组上游负载已饱和,请稍后再试": "The current group load is saturated, please try again later", "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。": "The current group load is saturated, please try again later, or upgrade your account to improve service quality.",
"令牌名称过长": "Token name is too long", "令牌名称长度必须在1-20之间": "The length of the token name must be between 1-20",
"令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.", "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期": "The token has expired and cannot be enabled. Please modify the expiration time of the token, or set it to never expire.",
"令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota", "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度": "The available quota of the token has been used up and cannot be enabled. Please modify the remaining quota of the token, or set it to unlimited quota",
"管理员关闭了密码登录": "The administrator has turned off password login", "管理员关闭了密码登录": "The administrator has turned off password login",
@@ -86,7 +81,6 @@
"该令牌已过期": "The token has expired", "该令牌已过期": "The token has expired",
"该令牌额度已用尽": "The token quota has been used up", "该令牌额度已用尽": "The token quota has been used up",
"无效的令牌": "Invalid token", "无效的令牌": "Invalid token",
"令牌验证失败": "Token verification failed",
"id 或 userId 为空!": "id or userId is empty!", "id 或 userId 为空!": "id or userId is empty!",
"quota 不能为负数!": "quota cannot be negative!", "quota 不能为负数!": "quota cannot be negative!",
"令牌额度不足": "Insufficient token quota", "令牌额度不足": "Insufficient token quota",
@@ -120,7 +114,6 @@
" 年 ": " y ", " 年 ": " y ",
"未测试": "Not tested", "未测试": "Not tested",
"通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.", "通道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有通道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.", "已成功开始测试所有已启用通道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!", "通道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!", "已更新完毕所有已启用通道余额!": "The balance of all enabled channels has been updated!",
@@ -141,7 +134,6 @@
"启用": "Enable", "启用": "Enable",
"编辑": "Edit", "编辑": "Edit",
"添加新的渠道": "Add a new channel", "添加新的渠道": "Add a new channel",
"测试所有通道": "Test all channels",
"测试所有已启用通道": "Test all enabled channels", "测试所有已启用通道": "Test all enabled channels",
"更新所有已启用通道余额": "Update the balance of all enabled channels", "更新所有已启用通道余额": "Update the balance of all enabled channels",
"刷新": "Refresh", "刷新": "Refresh",
@@ -232,7 +224,7 @@
"已是最新版本": "Is the latest version", "已是最新版本": "Is the latest version",
"检查更新": "Check for updates", "检查更新": "Check for updates",
"公告": "Announcement", "公告": "Announcement",
"在此输入新的公告内容,支持 Markdown & HTML 代码": "Enter the new announcement content here, supports Markdown & HTML code", "在此输入新的公告内容": "Enter new announcement content here",
"保存公告": "Save Announcement", "保存公告": "Save Announcement",
"个性化设置": "Personalization Settings", "个性化设置": "Personalization Settings",
"系统名称": "System Name", "系统名称": "System Name",
@@ -268,7 +260,7 @@
"注意": "Note", "注意": "Note",
"此处生成的令牌用于系统管理": "The token generated here is used for system management", "此处生成的令牌用于系统管理": "The token generated here is used for system management",
"而非用于请求 OpenAI 相关的服务": "Not for requesting OpenAI related services", "而非用于请求 OpenAI 相关的服务": "Not for requesting OpenAI related services",
"请知悉": "Please be aware", "请知悉": "Please be aware.",
"更新个人信息": "Update Personal Information", "更新个人信息": "Update Personal Information",
"生成系统访问令牌": "Generate System Access Token", "生成系统访问令牌": "Generate System Access Token",
"复制邀请链接": "Copy Invitation Link", "复制邀请链接": "Copy Invitation Link",
@@ -292,7 +284,7 @@
"兑换时间": "Redemption Time", "兑换时间": "Redemption Time",
"尚未兑换": "Not yet redeemed", "尚未兑换": "Not yet redeemed",
"已复制到剪贴板": "Copied to clipboard", "已复制到剪贴板": "Copied to clipboard",
"无法复制到剪贴板": "Unable to copy to clipboard", "无法复制到剪贴板": "Unable to copy to clipboard.",
"请手动复制": "Please copy manually", "请手动复制": "Please copy manually",
"已将兑换码填入搜索框": "The voucher code has been filled into the search box", "已将兑换码填入搜索框": "The voucher code has been filled into the search box",
"复制": "Copy", "复制": "Copy",
@@ -435,7 +427,7 @@
"一分钟后过期": "Expires after one minute", "一分钟后过期": "Expires after one minute",
"创建新的令牌": "Create New Token", "创建新的令牌": "Create New Token",
"注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.", "注意,令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note that the quota of the token is only used to limit the maximum quota usage of the token itself, and the actual usage is limited by the remaining quota of the account.",
"设为无限额度": "Set to unlimited quota", "设为无限额度": "Set to unlimited quota",
"更新令牌信息": "Update Token Information", "更新令牌信息": "Update Token Information",
"请输入充值码!": "Please enter the recharge code!", "请输入充值码!": "Please enter the recharge code!",
"请输入名称": "Please enter a name", "请输入名称": "Please enter a name",
@@ -459,7 +451,6 @@
"使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})", "使用明细(总消耗额度:{renderQuota(stat.quota)}": "Usage Details (Total Consumption Quota: {renderQuota(stat.quota)})",
"用户名称": "User Name", "用户名称": "User Name",
"令牌名称": "Token Name", "令牌名称": "Token Name",
"默认令牌": "Default Token",
"留空则查询全部用户": "Leave blank to query all users", "留空则查询全部用户": "Leave blank to query all users",
"留空则查询全部令牌": "Leave blank to query all tokens", "留空则查询全部令牌": "Leave blank to query all tokens",
"模型名称": "Model Name", "模型名称": "Model Name",
@@ -502,7 +493,6 @@
"参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)", "参数替换为你的部署名称(模型名称中的点会被剔除)": "Replace the parameter with your deployment name (dots in the model name will be removed)",
"模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!", "模型映射必须是合法的 JSON 格式!": "Model mapping must be in valid JSON format!",
"取消无限额度": "Cancel unlimited quota", "取消无限额度": "Cancel unlimited quota",
"取消": "Cancel",
"请输入新的剩余额度": "Please enter the new remaining quota", "请输入新的剩余额度": "Please enter the new remaining quota",
"请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code", "请输入单个兑换码中包含的额度": "Please enter the quota included in a single redemption code",
"请输入用户名": "Please enter username", "请输入用户名": "Please enter username",
@@ -514,264 +504,72 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel", "请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL", "Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL", "Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为通道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:", "兑换中": "Redeeming",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
"注意,": "Note that, ",
",图片演示。": "related image demo.",
"令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
"代理": "Proxy",
"此项可选,用于通过代理站来进行 API 调用请输入代理站地址格式为https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
"取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
"按照如下格式输入:": "Enter in the following format:",
"模型版本": "Model version",
"请输入星火大模型版本注意是接口地址中的版本号例如v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
"点击查看": "click to view",
"请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!",
"处理中...": "Processing...",
"绑定成功!": "Binding successful!",
"登录成功!": "Login successful!",
"操作失败,重定向至登录界面中...": "Operation failed, redirecting to login screen...",
"出现错误,第 ${count} 次重试中...": "An error occurred, retrying ${count}...",
"首页": "Home",
"渠道": "Channel",
"令牌": "API Keys",
"兑换": "Redeem",
"充值": "Recharge",
"用户": "Users",
"日志": "Logs",
"设置": "Settings",
"关于": "About",
"聊天": "Chat",
"注销成功!": "Logout successful!",
"注销": "Log out",
"登录": "Log in",
"注册": "Sign up",
"加载{name}中...": "Loading {name}...",
"未登录或登录已过期,请重新登录!": "Not logged in or login has expired, please log in again!",
"请立刻修改默认密码!": "Please change the default password immediately!",
"欢迎回来": "Welcome back",
"没有账户?": "No account?",
"立刻注册": "Sign up now",
"用户名": "Username",
"密码": "Password",
"正在登录……": "Logging in...",
"忘记密码": "Forgot password",
"其他方式": "Other methods",
"微信扫码关注公众号,输入「验证码」获取验证码(三分钟内有效)": "Scan the QR code with WeChat, follow the official account and enter 'verification code' to get the verification code (valid within three minutes)",
"验证码": "Verification code",
"全部用户": "All users",
"当前用户": "Current user",
"全部": "All",
"消费": "Consumption",
"管理": "Management",
"系统": "System",
"未知": "Unknown",
"其他模型": "Other models",
"复制成功": "Copy successful",
"使用明细": "Usages",
"刷新": "Refresh",
"收起面板": "Collapse panel",
"展开面板": "Expand panel",
"显示查询选项": "Show search options",
"隐藏查询选项": "Hide search options",
"用户名称": "User name",
"可选值": "Optional values",
"渠道 ID": "Channel ID",
"令牌名称": "Key name",
"模型名称": "Model name",
"起始时间": "Start time",
"结束时间": "End time",
"查询": "Query",
"隐藏条形图": "Hide bar chart",
"显示条形图": "Show bar chart",
"折线条形图只展示最新50条数据": "Line and bar charts only show the latest 50 pieces of data",
"总消耗": "Total consumption",
"总共调用了 {payload[0].value} 次": "A total of {payload[0].value} calls were made",
"{model.name}: {model.value} 次": "{model.name}: {model.value} times",
"总共调用了 {payload[0].value} 次 {payload[0].name}": "A total of {payload[0].value} {payload[0].name} calls were made",
"总消耗额度": "Total consumption limit",
"暂无数据": "No data available",
"更多数据统计图形即将到来,敬请期待!": "More data statistics graphics are coming soon, stay tuned!",
"复制用户名": "Copy username",
"{`共 ${counts} 条数据`}": "{`A total of ${counts} pieces of data`}",
"共 0 条数据": "A total of 0 pieces of data",
"选择明细分类": "Select detail category",
"模型倍率": "model rate",
"分组倍率": "group rate",
"新密码已复制到剪贴板:": "New password has been copied to the clipboard:",
"密码重置确认": "Password reset confirmation",
"邮箱地址": "Email address",
"新密码": "New password",
"密码已复制到剪贴板:": "Password has been copied to the clipboard:",
"密码重置完成": "Password reset complete",
"提交": "Submit",
"返回登录": "Return to login",
"请稍后重试,浏览器环境检查未通过": "Please try again later, browser environment check failed",
"重置邮件发送成功,请检查邮箱!": "Reset email sent successfully, please check your email!",
"密码重置": "Password reset",
"重试": "Retry",
"组": "Group",
"令牌已重置并已复制到剪贴板": "Token has been reset and copied to the clipboard",
"邀请链接已复制到剪切板": "Invitation link has been copied to the clipboard",
"系统令牌已复制到剪切板": "System token has been copied to the clipboard",
"请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
"账户已删除!": "Account has been deleted!",
"微信账户绑定成功!": "WeChat account binding successful!",
"请稍后几秒重试Turnstile 正在检查用户环境!": "Please try again in a few seconds, Turnstile is checking the user environment!",
"验证码发送成功,请检查邮箱!": "Verification code sent successfully, please check your email!",
"邮箱账户绑定成功!": "Email account binding successful!",
"个人信息": "Personal information",
"编辑个人信息": "Edit personal information",
"生成系统访问令牌": "Generate system access token",
"复制邀请链接": "Copy invitation link",
"删除个人帐户": "Delete personal account",
"普通用户": "Regular user",
"管理员": "Administrator",
"超级管理员": "Super administrator",
"显示名称": "Display name",
"GitHub 账号": "GitHub account",
"微信账号": "WeChat account",
"修改个人信息只允许在电脑端进行。生成的令牌用于系统管理,而非用于请求 OpenAI 相关的服务,请知悉。": "Modifying personal information is only allowed on a computer. The generated token is for system management, not for requesting OpenAI related services. Please be aware.",
"可用模型": "Available models",
"账号绑定": "Account binding",
"绑定微信": "Bind WeChat",
"绑定 GitHub": "Bind GitHub",
"绑定邮箱": "Bind Email",
"绑定": "Bind",
"绑定邮箱地址": "Bind email address",
"输入邮箱地址": "Enter email address",
"重新发送": "Resend",
"获取验证码": "Get verification code",
"确认绑定": "Confirm binding",
"取消": "Cancel",
"危险操作": "Dangerous operation",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your own account, all data will be cleared and cannot be recovered",
"输入你的账户名": "Enter your account name",
"以确认删除": "To confirm deletion",
"确认删除": "Confirm deletion",
"未使用": "Not used",
"已禁用": "Disabled",
"已使用": "Used",
"未知状态": "Unknown status",
"操作成功完成!": "Operation successfully completed!",
"搜索兑换码的 ID 和名称 ...": "Search for the ID and name of the redemption code ...",
"名称": "Name",
"状态": "Status",
"额度": "Quota",
"创建时间": "Creation time",
"兑换时间": "Redemption time",
"操作": "Operation",
"尚未兑换": "Not yet redeemed",
"已复制到剪贴板!": "Copied to clipboard!",
"无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually. The redemption code has been filled in the search box.",
"复制": "Copy",
"删除": "Delete",
"禁用": "Disable",
"启用": "Enable",
"编辑": "Edit",
"添加新的兑换码": "Add new redemption code",
"密码长度不得小于 8 位!": "Password length must not be less than 8 characters!",
"两次输入的密码不一致": "The two passwords entered do not match",
"注册成功!": "Registration successful!",
"请填写注册邮箱!": "Please fill in the registration email!",
"请在${verificationTimeout}秒后再试": "Please try again after ${verificationTimeout} seconds",
"验证码发送成功,请检查你的邮箱!": "Verification code sent successfully, please check your email!",
"已有账户?": "Already have an account?",
"请输入用户名(最长 12 位)": "Please enter a username (up to 12 characters)",
"请输入密码(最短 8 位,最长 20 位)": "Please enter a password (minimum 8 characters, maximum 20 characters)",
"请再次输入密码": "Please enter the password again",
"请输入邮箱地址": "Please enter an email address",
"秒后可重发": "Can be resent after seconds",
"请输入邮箱验证码": "Please enter the email verification code",
"已过期": "Expired",
"已启用": "Enabled",
"已耗尽": "Exhausted",
"无": "None",
"令牌密钥": "API Key",
"令牌状态": "Key status",
"已用额度": "Used quota",
"剩余额度": "Remaining quota",
"过期时间": "Expiration time",
"你确定要删除这个令牌吗?": "Are you sure you want to delete this key?",
"无法复制到剪贴板,请手动复制,已将令牌密钥填入搜索框": "Unable to copy to clipboard, please copy manually. The key key has been filled in the search box.",
"无限制": "Unlimited",
"永不过期": "Never expires",
"使用 API 访问令牌进行服务鉴权和计费。": "Use API Key for service authentication and billing.",
"API 访问令牌关系到您的个人利益,请妥善留存,不要与其他人共享,也不要保存在客户端代码中。": "API Key is related to your personal interests. Please keep it properly. Do not share it with others or save it in client code.",
"创建令牌": "Create Key",
"什么都还没有,快去创建一个令牌开始使用吧!": "Nothing yet, go create a key to start using!",
"你确定要删除该令牌吗": "Are you sure you want to delete this key",
"导出令牌信息": "Export key information",
"错误:未登录或登录已过期,请重新登录!": "Error: Not logged in or login has expired, please log in again!",
"错误:请求次数过多,请稍后再试!": "Error: Too many requests, please try again later!",
"错误:服务器内部错误,请联系管理员!": "Error: Server internal error, please contact the online customer service!",
"本站仅作演示之用,无服务端!": "This site is for demonstration purposes only, no server!",
"错误:": "Error:",
"加载首页内容失败...": "Failed to load homepage content...",
"系统状况": "System status",
"系统信息": "System information",
"系统信息总览": "System information overview",
"名称:": "Name:",
"版本:": "Version:",
"源码:": "Source code:",
"启动时间:": "Startup time:",
"系统配置": "System configuration",
"系统配置总览": "System configuration overview",
"邮箱验证:": "Email verification:",
"未启用": "Not enabled",
"Turnstile 用户校验:": "Turnstile user verification:",
"页面不存在": "Page does not exist",
"请检查你的浏览器地址是否正确": "Please check if your browser address is correct",
"个人设置": "Personal settings",
"运营设置": "Operations settings",
"系统设置": "System settings",
"其他设置": "Other settings",
"默认令牌": "Default key",
"过期时间必须在当前时间之后!": "Expiration time must be after the current time!",
"额度必须大于等于 0": "Quota must be greater than or equal to 0!",
"过期时间格式错误!": "Expiration time format error!",
"创建令牌数量必须大于等于 1": "The number of keys to create must be greater than or equal to 1!",
"令牌修改成功": "API Key modification successful",
"令牌创建成功": "API Key creation successful",
"更新令牌信息": "Update key information",
"创建新的令牌": "Create a new key",
"请输入名称": "Please enter a name",
"请输入过期时间,格式为 yyyy-MM-dd HH:mm:ss-1 表示无限制": "Please enter the expiration time, the format is yyyy-MM-dd HH:mm:ss, -1 means unlimited",
"无限额度": "Unlimited quota",
"注意:启用无限额度后,已用额度将不再进行计算。": "Note: After enabling unlimited quota, the used quota will no longer be calculated.",
"等于": "Equals",
"请输入额度单位token": "Please enter the quota (unit: token)",
"创建令牌数量": "Create key quantity",
"请输入令牌数量": "Please enter the number of keys",
"注意:令牌的额度仅用于限制令牌本身的最大额度使用量,实际的使用受到账户的剩余额度限制。": "Note: The quota of the key is only used to limit the maximum quota usage of the key itself, and the actual usage is subject to the remaining quota of the account.",
"我的令牌": "My keys",
"请输入额度兑换码!": "Please enter the redeem code!",
"充值成功!": "Recharge successful!",
"请求失败": "Request failed", "请求失败": "Request failed",
"超级管理员未设置充值链接!": "The super administrator did not set a recharge link!", "百度文心千帆": "Baidu Wenxin Qianfan",
"充值额度": "Recharge quota", "智谱": "Zhipuai",
"兑换中...": "Redeeming...", "代理:": "Agent: ",
"请点击充值以获取额度兑换码。": "Please click recharge to get the quota redemption code.", "请输入你的账户名以确认删除!": "Please enter your account name to confirm deletion!",
"用户信息更新成功": "User information updated successfully!", "账户已删除": "Account deleted!",
"更新用户信息": "Update user information", "删除个人账户": "Delete personal account",
"请输入新的用户名": "Please enter a new username", "重新发送": "Resend ",
"请输入新的密码,最短 8 位": "Please enter a new password, at least 8 characters", "确认删除自己的帐户": "Confirm deletion of your own account",
"输入新的显示名称": "Please enter a new display name", "输入你的账户名": "Enter your account name",
"分组": "Group", "以确认删除": "to confirm deletion",
"请选择分组": "Please select a group", "重试": "Retry",
"请在系统设置页面编辑分组倍率以添加新的分组:": "Please edit the group rate on the system settings page to add a new group:", "无法复制到剪贴板,请手动复制,已将兑换码填入搜索框。": "Unable to copy to clipboard, please copy manually, the redemption code has been filled in the search box.",
"请输入新的剩余额度": "Please enter a new remaining quota", "密码重置完成": "Password reset completed",
"已绑定的 GitHub 账户": "Bound GitHub account",
"此项只读,需要用户通过个人设置页面的相关绑定按钮进行绑定,不可直接修改": "This item is read-only, users need to bind through the relevant binding button on the personal settings page, cannot be directly modified", "流式请求和非流式请求不能同时禁用!": "Streaming requests and non-streaming requests cannot be disabled at the same time!",
"已绑定的微信账户": "Bound WeChat account",
"已绑定的邮箱账户": "Bound email account", "请为渠道命名": "Please name the channel",
"新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面": "New version available: ${data.version}, please refresh the page using the shortcut key Shift + F5", "请选择可以使用该渠道的分组": "Please select the group that can use this channel",
"无法正常连接至服务器!": "Unable to connect to the server normally!", "请选择该渠道所支持的模型": "Please select the models supported by this channel",
"提示:": "Input:", "填入": "Fill in",
"补全:": "Output:", "此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"搜索令牌名称": "Search key name", "允许流式请求": "Allow streaming requests",
"测试所有渠道": "Test all channels", "允许非流式请求": "Allow non-streaming requests",
"更新已启用渠道余额": "Update the balance of enabled channels" "请输入 access token当前版本暂不支持自动刷新请每 30 天更新一次": "Please enter the access token, the current version does not support automatic refresh, please update it every 30 days",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",
"此项可选用于通过Mirror站来进行 API 调用请EnterMirror站地址格式为https://domain.com": "This is optional, used to make API calls through the Mirror site, please enter the Mirror site address, the format is: https://domain.com",
"新密码": "New password",
"新密码已复制到剪贴板:": "New password copied to clipboard: ",
"兑换失败,": "Redemption failed, ",
"当前分组 %s 下对于模型 %s 无可用渠道": "There are no available channels for model %s under the current group %s",
"无权将其他用户权限等级提升到大于等于自己的权限等级": "You are not allowed to raise the permission level of other users to greater than or equal to your own permission level",
"不能删除超级管理员账户": "Cannot delete super administrator account",
"无效参数": "Invalid parameter",
"访问令牌无效": "Access token invalid",
"Google 用户信息无效": "Google user information invalid",
"返回值无效,用户字段为空,请稍后再试!": "The return value is invalid, the user field is empty, please try again later!",
"管理员未开启通过 Google 登录以及注册": "The administrator has not enabled login and registration via Google",
"该 Google 账户已被绑定": "The Google account has been bound",
"点击 <a href='%s'>此处</a> 进行密码重置。": "Click <a href='%s'>here</a> to reset your password.",
"如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:<br> %s ": "If the link cannot be clicked, please try to click the link below or copy it to the browser to open: <br> %s ",
"该渠道类型当前版本不支持测试,请手动测试": "The current version of the channel type does not support testing, please test manually",
"无法启用 Discord OAuth请先填入 Discord Client ID 以及 Discord Client Secret": "Unable to enable Discord OAuth, please fill in the Discord Client ID and Discord Client Secret first!",
"无法启用 Google OAuth请先填入 Google Client ID 以及 Google Client Secret": "Unable to enable Google OAuth, please fill in the Google Client ID and Google Client Secret first!",
"管理员未开启通过 Discord 登录以及注册": "The administrator has not enabled login and registration via Discord",
"该 Discord 账户已被绑定": "The Discord account has been bound",
"Discord 用户信息无效": "Discord user information invalid",
"绑定 Discord 账号": "Bind Discord account",
"绑定 Google 账号": "Bind Google account",
"已绑定的 Discord 账户": "Bound Discord account",
"已绑定的 Google 账户": "Bound Google account",
"身份验证": "Authentication",
"允许通过 Discord 账户登录和注册": "Allow login and registration via Discord account",
"允许通过 Google 账户登录和注册": "Allow login and registration via Google account",
"配置 Discord OAuth 应用程序": "Configure Discord OAuth application",
"配置 Google OAuth 应用程序": "Configure Google OAuth application",
"用以支持通过 Discord 进行登录注册,": "Used to support login and registration via Discord, ",
"用以支持通过 Google 进行登录注册,": "Used to support login and registration via Google, ",
"管理你的": "Manage your ",
"客户 ID": "Client ID",
"客户秘密": "Client Secret",
"失败重试次数": "Retry times",
"保存": "Save",
"输入您注册的 Discord OAuth APP 的 ID": "Enter the ID of your registered Discord OAuth APP",
"输入您注册的 Google OAuth APP 的 ID": "Enter the ID of your registered Google OAuth APP"
} }

75
main.go
View File

@@ -2,107 +2,96 @@ package main
import ( import (
"embed" "embed"
"fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie" "github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "one-api/common"
"github.com/songquanpeng/one-api/common/config" "one-api/controller"
"github.com/songquanpeng/one-api/common/logger" "one-api/middleware"
"github.com/songquanpeng/one-api/controller" "one-api/model"
"github.com/songquanpeng/one-api/middleware" "one-api/router"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/router"
"os" "os"
"strconv" "strconv"
) )
//go:embed web/build/* //go:embed web/build
var buildFS embed.FS var buildFS embed.FS
//go:embed web/build/index.html
var indexPage []byte
func main() { func main() {
logger.SetupLogger() common.SetupGinLog()
logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) common.SysLog("One API " + common.Version + " started")
if os.Getenv("GIN_MODE") != "debug" { if os.Getenv("GIN_MODE") != "debug" {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
if config.DebugEnabled {
logger.SysLog("running in debug mode")
}
// Initialize SQL Database // Initialize SQL Database
err := model.InitDB() err := model.InitDB()
if err != nil { if err != nil {
logger.FatalLog("failed to initialize database: " + err.Error()) common.FatalLog("failed to initialize database: " + err.Error())
} }
defer func() { defer func() {
err := model.CloseDB() err := model.CloseDB()
if err != nil { if err != nil {
logger.FatalLog("failed to close database: " + err.Error()) common.FatalLog("failed to close database: " + err.Error())
} }
}() }()
// Initialize Redis // Initialize Redis
err = common.InitRedisClient() err = common.InitRedisClient()
if err != nil { if err != nil {
logger.FatalLog("failed to initialize Redis: " + err.Error()) common.FatalLog("failed to initialize Redis: " + err.Error())
} }
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
logger.SysLog(fmt.Sprintf("using theme %s", config.Theme))
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions
config.MemoryCacheEnabled = true
}
if config.MemoryCacheEnabled {
logger.SysLog("memory cache enabled")
logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency))
model.InitChannelCache() model.InitChannelCache()
} }
if config.MemoryCacheEnabled { if os.Getenv("SYNC_FREQUENCY") != "" {
go model.SyncOptions(config.SyncFrequency) frequency, err := strconv.Atoi(os.Getenv("SYNC_FREQUENCY"))
go model.SyncChannelCache(config.SyncFrequency) if err != nil {
common.FatalLog("failed to parse SYNC_FREQUENCY: " + err.Error())
}
common.SyncFrequency = frequency
go model.SyncOptions(frequency)
if common.RedisEnabled {
go model.SyncChannelCache(frequency)
}
} }
if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY"))
if err != nil { if err != nil {
logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error())
} }
go controller.AutomaticallyUpdateChannels(frequency) go controller.AutomaticallyUpdateChannels(frequency)
} }
if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" {
frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY"))
if err != nil { if err != nil {
logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error())
} }
go controller.AutomaticallyTestChannels(frequency) go controller.AutomaticallyTestChannels(frequency)
} }
if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
config.BatchUpdateEnabled = true
logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s")
model.InitBatchUpdater()
}
openai.InitTokenEncoders()
// Initialize HTTP server // Initialize HTTP server
server := gin.New() server := gin.Default()
server.Use(gin.Recovery())
// This will cause SSE not to work!!! // This will cause SSE not to work!!!
//server.Use(gzip.Gzip(gzip.DefaultCompression)) //server.Use(gzip.Gzip(gzip.DefaultCompression))
server.Use(middleware.RequestId()) server.Use(middleware.CORS())
middleware.SetUpLogger(server)
// Initialize session store // Initialize session store
store := cookie.NewStore([]byte(config.SessionSecret)) store := cookie.NewStore([]byte(common.SessionSecret))
server.Use(sessions.Sessions("session", store)) server.Use(sessions.Sessions("session", store))
router.SetRouter(server, buildFS) router.SetRouter(server, buildFS, indexPage)
var port = os.Getenv("PORT") var port = os.Getenv("PORT")
if port == "" { if port == "" {
port = strconv.Itoa(*common.Port) port = strconv.Itoa(*common.Port)
} }
err = server.Run(":" + port) err = server.Run(":" + port)
if err != nil { if err != nil {
logger.FatalLog("failed to start HTTP server: " + err.Error()) common.FatalLog("failed to start HTTP server: " + err.Error())
} }
} }

View File

@@ -3,9 +3,9 @@ package middleware
import ( import (
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strings" "strings"
) )
@@ -91,26 +91,45 @@ func TokenAuth() func(c *gin.Context) {
key = parts[0] key = parts[0]
token, err := model.ValidateUserToken(key) token, err := model.ValidateUserToken(key)
if err != nil { if err != nil {
abortWithMessage(c, http.StatusUnauthorized, err.Error()) c.JSON(http.StatusUnauthorized, gin.H{
"error": gin.H{
"message": err.Error(),
"type": "one_api_error",
},
})
c.Abort()
return return
} }
userEnabled, err := model.CacheIsUserEnabled(token.UserId) if !model.CacheIsUserEnabled(token.UserId) {
if err != nil { c.JSON(http.StatusForbidden, gin.H{
abortWithMessage(c, http.StatusInternalServerError, err.Error()) "error": gin.H{
return "message": "用户已被封禁",
} "type": "one_api_error",
if !userEnabled { },
abortWithMessage(c, http.StatusForbidden, "用户已被封禁") })
c.Abort()
return return
} }
c.Set("id", token.UserId) c.Set("id", token.UserId)
c.Set("token_id", token.Id) c.Set("token_id", token.Id)
c.Set("token_name", token.Name) c.Set("token_name", token.Name)
requestURL := c.Request.URL.String()
consumeQuota := true
if strings.HasPrefix(requestURL, "/v1/models") {
consumeQuota = false
}
c.Set("consume_quota", consumeQuota)
if len(parts) > 1 { if len(parts) > 1 {
if model.IsAdmin(token.UserId) { if model.IsAdmin(token.UserId) {
c.Set("channelId", parts[1]) c.Set("channelId", parts[1])
} else { } else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "普通用户不支持指定渠道",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }

View File

@@ -2,10 +2,10 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "log"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"one-api/common"
"one-api/model"
"strconv" "strconv"
"strings" "strings"
@@ -13,7 +13,8 @@ import (
) )
type ModelRequest struct { type ModelRequest struct {
Model string `json:"model"` Model string `json:"model"`
Stream bool `json:"stream" default:"true"`
} }
func Distribute() func(c *gin.Context) { func Distribute() func(c *gin.Context) {
@@ -26,16 +27,34 @@ func Distribute() func(c *gin.Context) {
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
channel, err = model.GetChannelById(id, true) channel, err = model.GetChannelById(id, true)
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的渠道 ID",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != common.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") c.JSON(http.StatusForbidden, gin.H{
"error": gin.H{
"message": "该渠道已被禁用",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} else { } else {
@@ -43,7 +62,13 @@ func Distribute() func(c *gin.Context) {
var modelRequest ModelRequest var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest) err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil { if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求") c.JSON(http.StatusBadRequest, gin.H{
"error": gin.H{
"message": "无效的请求",
"type": "one_api_error",
},
})
c.Abort()
return return
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
@@ -58,47 +83,35 @@ func Distribute() func(c *gin.Context) {
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2" modelRequest.Model = "dall-e"
} }
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { log.Print(modelRequest.Stream)
if modelRequest.Model == "" { channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, modelRequest.Stream)
modelRequest.Model = "whisper-1"
}
}
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
if err != nil { if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
if channel != nil { if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员" message = "数据库一致性已被破坏,请联系管理员"
} }
abortWithMessage(c, http.StatusServiceUnavailable, message) c.JSON(http.StatusServiceUnavailable, gin.H{
"error": gin.H{
"message": message,
"type": "one_api_error",
},
})
c.Abort()
return return
} }
} }
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id) c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name) c.Set("channel_name", channel.Name)
c.Set("model_mapping", channel.GetModelMapping()) c.Set("model_mapping", channel.ModelMapping)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.BaseURL)
// this is for backward compatibility if channel.Type == common.ChannelTypeAzure {
switch channel.Type { c.Set("api_version", channel.Other)
case common.ChannelTypeAzure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
} }
c.Next() c.Next()
} }

View File

@@ -1,25 +0,0 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
)
func SetUpLogger(server *gin.Engine) {
server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var requestID string
if param.Keys != nil {
requestID = param.Keys[logger.RequestIdKey].(string)
}
return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
requestID,
param.StatusCode,
param.Latency,
param.ClientIP,
param.Method,
param.Path,
)
}))
}

View File

@@ -4,9 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"net/http" "net/http"
"one-api/common"
"time" "time"
) )
@@ -27,7 +26,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
} }
if listLength < int64(maxRequestNum) { if listLength < int64(maxRequestNum) {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
} else { } else {
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
oldTime, err := time.Parse(timeFormat, oldTimeStr) oldTime, err := time.Parse(timeFormat, oldTimeStr)
@@ -48,14 +47,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
// time.Since will return negative number! // time.Since will return negative number!
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
if int64(nowTime.Sub(oldTime).Seconds()) < duration { if int64(nowTime.Sub(oldTime).Seconds()) < duration {
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
c.Status(http.StatusTooManyRequests) c.Status(http.StatusTooManyRequests)
c.Abort() c.Abort()
return return
} else { } else {
rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LPush(ctx, key, time.Now().Format(timeFormat))
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration)
} }
} }
} }
@@ -76,7 +75,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
} else { } else {
// It's safe to call multi times. // It's safe to call multi times.
inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration)
return func(c *gin.Context) { return func(c *gin.Context) {
memoryRateLimiter(c, maxRequestNum, duration, mark) memoryRateLimiter(c, maxRequestNum, duration, mark)
} }
@@ -84,21 +83,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi
} }
func GlobalWebRateLimit() func(c *gin.Context) { func GlobalWebRateLimit() func(c *gin.Context) {
return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW")
} }
func GlobalAPIRateLimit() func(c *gin.Context) { func GlobalAPIRateLimit() func(c *gin.Context) {
return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA")
} }
func CriticalRateLimit() func(c *gin.Context) { func CriticalRateLimit() func(c *gin.Context) {
return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT")
} }
func DownloadRateLimit() func(c *gin.Context) { func DownloadRateLimit() func(c *gin.Context) {
return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW")
} }
func UploadRateLimit() func(c *gin.Context) { func UploadRateLimit() func(c *gin.Context) {
return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP")
} }

View File

@@ -1,28 +0,0 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"net/http"
"runtime/debug"
)
func RelayPanicRecover() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
logger.SysError(fmt.Sprintf("panic detected: %v", err))
logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack())))
c.JSON(http.StatusInternalServerError, gin.H{
"error": gin.H{
"message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err),
"type": "one_api_panic",
},
})
c.Abort()
}
}()
c.Next()
}
}

View File

@@ -1,19 +0,0 @@
package middleware
import (
"context"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := helper.GetTimeString() + helper.GetRandomString(8)
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)
c.Header(logger.RequestIdKey, id)
c.Next()
}
}

View File

@@ -4,10 +4,9 @@ import (
"encoding/json" "encoding/json"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"net/http" "net/http"
"net/url" "net/url"
"one-api/common"
) )
type turnstileCheckResponse struct { type turnstileCheckResponse struct {
@@ -16,7 +15,7 @@ type turnstileCheckResponse struct {
func TurnstileCheck() gin.HandlerFunc { func TurnstileCheck() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if config.TurnstileCheckEnabled { if common.TurnstileCheckEnabled {
session := sessions.Default(c) session := sessions.Default(c)
turnstileChecked := session.Get("turnstile") turnstileChecked := session.Get("turnstile")
if turnstileChecked != nil { if turnstileChecked != nil {
@@ -33,12 +32,12 @@ func TurnstileCheck() gin.HandlerFunc {
return return
} }
rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{
"secret": {config.TurnstileSecretKey}, "secret": {common.TurnstileSecretKey},
"response": {response}, "response": {response},
"remoteip": {c.ClientIP()}, "remoteip": {c.ClientIP()},
}) })
if err != nil { if err != nil {
logger.SysError(err.Error()) common.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),
@@ -50,7 +49,7 @@ func TurnstileCheck() gin.HandlerFunc {
var res turnstileCheckResponse var res turnstileCheckResponse
err = json.NewDecoder(rawRes.Body).Decode(&res) err = json.NewDecoder(rawRes.Body).Decode(&res)
if err != nil { if err != nil {
logger.SysError(err.Error()) common.SysError(err.Error())
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": err.Error(), "message": err.Error(),

View File

@@ -1,18 +0,0 @@
package middleware
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.JSON(statusCode, gin.H{
"error": gin.H{
"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)),
"type": "one_api_error",
},
})
c.Abort()
logger.Error(c.Request.Context(), message)
}

View File

@@ -1,34 +1,47 @@
package model package model
import ( import (
"github.com/songquanpeng/one-api/common" "fmt"
"one-api/common"
"strings" "strings"
) )
type Ability struct { type Ability struct {
Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"` Group string `json:"group" gorm:"type:varchar(32);primaryKey;autoIncrement:false"`
Model string `json:"model" gorm:"primaryKey;autoIncrement:false"` Model string `json:"model" gorm:"primaryKey;autoIncrement:false"`
ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` ChannelId int `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"` AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
} }
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func GetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
ability := Ability{} ability := Ability{}
groupCol := "`group`" var err error = nil
trueVal := "1"
cmd := "`group` = ? and model = ? and enabled = 1"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
groupCol = `"group"` // Make cmd compatible with PostgreSQL
trueVal = "true" cmd = "\"group\" = ? and model = ? and enabled = true"
} }
var err error = nil if stream {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model) cmd += fmt.Sprintf(" and allow_streaming = %d", common.ChannelAllowStreamEnabled)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else { } else {
err = channelQuery.Order("RAND()").First(&ability).Error cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
cmd += fmt.Sprintf(" and allow_non_streaming = %d", common.ChannelAllowNonStreamEnabled)
}
if common.UsingSQLite {
err = DB.Where(cmd, group, model).Order("RANDOM()").Limit(1).First(&ability).Error
} else {
err = DB.Where(cmd, group, model).Order("RAND()").Limit(1).First(&ability).Error
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -46,11 +59,12 @@ func (channel *Channel) AddAbilities() error {
for _, model := range models_ { for _, model := range models_ {
for _, group := range groups_ { for _, group := range groups_ {
ability := Ability{ ability := Ability{
Group: group, Group: group,
Model: model, Model: model,
ChannelId: channel.Id, ChannelId: channel.Id,
Enabled: channel.Status == common.ChannelStatusEnabled, Enabled: channel.Status == common.ChannelStatusEnabled,
Priority: channel.Priority, AllowStreaming: channel.AllowStreaming,
AllowNonStreaming: channel.AllowNonStreaming,
} }
abilities = append(abilities, ability) abilities = append(abilities, ability)
} }

View File

@@ -4,11 +4,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"math/rand" "math/rand"
"sort" "one-api/common"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -16,25 +13,27 @@ import (
) )
var ( var (
TokenCacheSeconds = config.SyncFrequency TokenCacheSeconds = common.SyncFrequency
UserId2GroupCacheSeconds = config.SyncFrequency UserId2GroupCacheSeconds = common.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = common.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = common.SyncFrequency
) )
func CacheGetTokenByKey(key string) (*Token, error) { func CacheGetTokenByKey(key string) (*Token, error) {
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
var token Token var token Token
whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
if !common.RedisEnabled { if !common.RedisEnabled {
err := DB.Where(keyCol+" = ?", key).First(&token).Error err := DB.Where(whereItem, key).First(&token).Error
return &token, err return &token, err
} }
tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key)) tokenObjectString, err := common.RedisGet(fmt.Sprintf("token:%s", key))
if err != nil { if err != nil {
err := DB.Where(keyCol+" = ?", key).First(&token).Error err := DB.Where(whereItem, key).First(&token).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -44,7 +43,7 @@ func CacheGetTokenByKey(key string) (*Token, error) {
} }
err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second)
if err != nil { if err != nil {
logger.SysError("Redis set token error: " + err.Error()) common.SysError("Redis set token error: " + err.Error())
} }
return &token, nil return &token, nil
} }
@@ -64,7 +63,7 @@ func CacheGetUserGroup(id int) (group string, err error) {
} }
err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
if err != nil { if err != nil {
logger.SysError("Redis set user group error: " + err.Error()) common.SysError("Redis set user group error: " + err.Error())
} }
} }
return group, err return group, err
@@ -82,7 +81,7 @@ func CacheGetUserQuota(id int) (quota int, err error) {
} }
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil { if err != nil {
logger.SysError("Redis set user quota error: " + err.Error()) common.SysError("Redis set user quota error: " + err.Error())
} }
return quota, err return quota, err
} }
@@ -102,36 +101,23 @@ func CacheUpdateUserQuota(id int) error {
return err return err
} }
func CacheDecreaseUserQuota(id int, quota int) error { func CacheIsUserEnabled(userId int) bool {
if !common.RedisEnabled {
return nil
}
err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
return err
}
func CacheIsUserEnabled(userId int) (bool, error) {
if !common.RedisEnabled { if !common.RedisEnabled {
return IsUserEnabled(userId) return IsUserEnabled(userId)
} }
enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId)) enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
if err == nil {
return enabled == "1", nil
}
userEnabled, err := IsUserEnabled(userId)
if err != nil { if err != nil {
return false, err status := common.UserStatusDisabled
if IsUserEnabled(userId) {
status = common.UserStatusEnabled
}
enabled = fmt.Sprintf("%d", status)
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
common.SysError("Redis set user enabled error: " + err.Error())
}
} }
enabled = "0" return enabled == "1"
if userEnabled {
enabled = "1"
}
err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set user enabled error: " + err.Error())
}
return userEnabled, err
} }
var group2model2channels map[string]map[string][]*Channel var group2model2channels map[string]map[string][]*Channel
@@ -166,34 +152,23 @@ func InitChannelCache() {
} }
} }
} }
// sort by priority
for group, model2channels := range newGroup2model2channels {
for model, channels := range model2channels {
sort.Slice(channels, func(i, j int) bool {
return channels[i].GetPriority() > channels[j].GetPriority()
})
newGroup2model2channels[group][model] = channels
}
}
channelSyncLock.Lock() channelSyncLock.Lock()
group2model2channels = newGroup2model2channels group2model2channels = newGroup2model2channels
channelSyncLock.Unlock() channelSyncLock.Unlock()
logger.SysLog("channels synced from database") common.SysLog("channels synced from database")
} }
func SyncChannelCache(frequency int) { func SyncChannelCache(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Second) time.Sleep(time.Duration(frequency) * time.Second)
logger.SysLog("syncing channels from database") common.SysLog("syncing channels from database")
InitChannelCache() InitChannelCache()
} }
} }
func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { func CacheGetRandomSatisfiedChannel(group string, model string, stream bool) (*Channel, error) {
if !config.MemoryCacheEnabled { if !common.RedisEnabled {
return GetRandomSatisfiedChannel(group, model) return GetRandomSatisfiedChannel(group, model, stream)
} }
channelSyncLock.RLock() channelSyncLock.RLock()
defer channelSyncLock.RUnlock() defer channelSyncLock.RUnlock()
@@ -201,17 +176,14 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error
if len(channels) == 0 { if len(channels) == 0 {
return nil, errors.New("channel not found") return nil, errors.New("channel not found")
} }
endIdx := len(channels)
// choose by priority var filteredChannels []*Channel
firstChannel := channels[0] for _, channel := range channels {
if firstChannel.GetPriority() > 0 { if (stream && channel.AllowStreaming == common.ChannelAllowStreamEnabled) || (!stream && channel.AllowNonStreaming == common.ChannelAllowNonStreamEnabled) {
for i := range channels { filteredChannels = append(filteredChannels, channel)
if channels[i].GetPriority() != firstChannel.GetPriority() {
endIdx = i
break
}
} }
} }
idx := rand.Intn(endIdx)
return channels[idx], nil idx := rand.Intn(len(filteredChannels))
return filteredChannels[idx], nil
} }

View File

@@ -1,12 +1,8 @@
package model package model
import ( import (
"encoding/json" "one-api/common"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -16,20 +12,20 @@ type Channel struct {
Key string `json:"key" gorm:"not null;index"` Key string `json:"key" gorm:"not null;index"`
Status int `json:"status" gorm:"default:1"` Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"` Name string `json:"name" gorm:"index"`
Weight *uint `json:"weight" gorm:"default:0"` Weight int `json:"weight"`
CreatedTime int64 `json:"created_time" gorm:"bigint"` CreatedTime int64 `json:"created_time" gorm:"bigint"`
TestTime int64 `json:"test_time" gorm:"bigint"` TestTime int64 `json:"test_time" gorm:"bigint"`
ResponseTime int `json:"response_time"` // in milliseconds ResponseTime int `json:"response_time"` // in milliseconds
BaseURL *string `json:"base_url" gorm:"column:base_url;default:''"` BaseURL string `json:"base_url" gorm:"column:base_url"`
Other string `json:"other"` // DEPRECATED: please save config to field Config Other string `json:"other"`
Balance float64 `json:"balance"` // in USD Balance float64 `json:"balance"` // in USD
BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"` BalanceUpdatedTime int64 `json:"balance_updated_time" gorm:"bigint"`
Models string `json:"models"` Models string `json:"models"`
Group string `json:"group" gorm:"type:varchar(32);default:'default'"` Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` ModelMapping string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
Priority *int64 `json:"priority" gorm:"bigint;default:0"` AllowStreaming int `json:"allow_streaming" gorm:"default:1"`
Config string `json:"config"` AllowNonStreaming int `json:"allow_non_streaming" gorm:"default:1"`
} }
func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
@@ -44,11 +40,13 @@ func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) {
} }
func SearchChannels(keyword string) (channels []*Channel, err error) { func SearchChannels(keyword string) (channels []*Channel, err error) {
keyCol := "`key`" whereItem := "id = ? or name LIKE ? or `key` = ?"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
keyCol = `"key"` whereItem = "id = ? or name LIKE ? or \"key\" = ?"
} }
err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error
err = DB.Omit("key").Where(whereItem, keyword, keyword+"%", keyword).Find(&channels).Error
return channels, err return channels, err
} }
@@ -63,6 +61,19 @@ func GetChannelById(id int, selectAll bool) (*Channel, error) {
return &channel, err return &channel, err
} }
func GetRandomChannel() (*Channel, error) {
channel := Channel{}
var err error = nil
if common.UsingPostgreSQL {
err = DB.Where("status = ? and \"group\" = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else if common.UsingSQLite {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RANDOM()").Limit(1).First(&channel).Error
} else {
err = DB.Where("status = ? and `group` = ?", common.ChannelStatusEnabled, "default").Order("RAND()").Limit(1).First(&channel).Error
}
return &channel, err
}
func BatchInsertChannels(channels []Channel) error { func BatchInsertChannels(channels []Channel) error {
var err error var err error
err = DB.Create(&channels).Error err = DB.Create(&channels).Error
@@ -78,33 +89,6 @@ func BatchInsertChannels(channels []Channel) error {
return nil return nil
} }
func (channel *Channel) GetPriority() int64 {
if channel.Priority == nil {
return 0
}
return *channel.Priority
}
func (channel *Channel) GetBaseURL() string {
if channel.BaseURL == nil {
return ""
}
return *channel.BaseURL
}
func (channel *Channel) GetModelMapping() map[string]string {
if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" {
return nil
}
modelMapping := make(map[string]string)
err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping)
if err != nil {
logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error()))
return nil
}
return modelMapping
}
func (channel *Channel) Insert() error { func (channel *Channel) Insert() error {
var err error var err error
err = DB.Create(channel).Error err = DB.Create(channel).Error
@@ -128,21 +112,21 @@ func (channel *Channel) Update() error {
func (channel *Channel) UpdateResponseTime(responseTime int64) { func (channel *Channel) UpdateResponseTime(responseTime int64) {
err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{
TestTime: helper.GetTimestamp(), TestTime: common.GetTimestamp(),
ResponseTime: int(responseTime), ResponseTime: int(responseTime),
}).Error }).Error
if err != nil { if err != nil {
logger.SysError("failed to update response time: " + err.Error()) common.SysError("failed to update response time: " + err.Error())
} }
} }
func (channel *Channel) UpdateBalance(balance float64) { func (channel *Channel) UpdateBalance(balance float64) {
err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{
BalanceUpdatedTime: helper.GetTimestamp(), BalanceUpdatedTime: common.GetTimestamp(),
Balance: balance, Balance: balance,
}).Error }).Error
if err != nil { if err != nil {
logger.SysError("failed to update balance: " + err.Error()) common.SysError("failed to update balance: " + err.Error())
} }
} }
@@ -156,50 +140,20 @@ func (channel *Channel) Delete() error {
return err return err
} }
func (channel *Channel) LoadConfig() (map[string]string, error) {
if channel.Config == "" {
return nil, nil
}
cfg := make(map[string]string)
err := json.Unmarshal([]byte(channel.Config), &cfg)
if err != nil {
return nil, err
}
return cfg, nil
}
func UpdateChannelStatusById(id int, status int) { func UpdateChannelStatusById(id int, status int) {
err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
if err != nil { if err != nil {
logger.SysError("failed to update ability status: " + err.Error()) common.SysError("failed to update ability status: " + err.Error())
} }
err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error
if err != nil { if err != nil {
logger.SysError("failed to update channel status: " + err.Error()) common.SysError("failed to update channel status: " + err.Error())
} }
} }
func UpdateChannelUsedQuota(id int, quota int) { func UpdateChannelUsedQuota(id int, quota int) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
}
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil { if err != nil {
logger.SysError("failed to update channel used quota: " + err.Error()) common.SysError("failed to update channel used quota: " + err.Error())
} }
} }
func DeleteChannelByStatus(status int64) (int64, error) {
result := DB.Where("status = ?", status).Delete(&Channel{})
return result.RowsAffected, result.Error
}
func DeleteDisabledChannel() (int64, error) {
result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}

View File

@@ -1,29 +1,22 @@
package model package model
import ( import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Log struct { type Log struct {
Id int `json:"id"` Id int `json:"id"`
UserId int `json:"user_id" gorm:"index"` UserId int `json:"user_id"`
CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` CreatedAt int64 `json:"created_at" gorm:"bigint;index"`
Type int `json:"type" gorm:"index:idx_created_at_type"` Type int `json:"type" gorm:"index"`
Content string `json:"content"` Content string `json:"content"`
Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` Username string `json:"username" gorm:"index;default:''"`
TokenName string `json:"token_name" gorm:"index;default:''"` TokenName string `json:"token_name" gorm:"index;default:''"`
ModelName string `json:"model_name" gorm:"index;index:index_username_model_name,priority:1;default:''"` ModelName string `json:"model_name" gorm:"index;default:''"`
Quota int `json:"quota" gorm:"default:0"` Quota int `json:"quota" gorm:"default:0"`
PromptTokens int `json:"prompt_tokens" gorm:"default:0"` PromptTokens int `json:"prompt_tokens" gorm:"default:0"`
CompletionTokens int `json:"completion_tokens" gorm:"default:0"` CompletionTokens int `json:"completion_tokens" gorm:"default:0"`
ChannelId int `json:"channel" gorm:"index"`
} }
const ( const (
@@ -35,31 +28,30 @@ const (
) )
func RecordLog(userId int, logType int, content string) { func RecordLog(userId int, logType int, content string) {
if logType == LogTypeConsume && !config.LogConsumeEnabled { if logType == LogTypeConsume && !common.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(), CreatedAt: common.GetTimestamp(),
Type: logType, Type: logType,
Content: content, Content: content,
} }
err := DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
logger.SysError("failed to record log: " + err.Error()) common.SysError("failed to record log: " + err.Error())
} }
} }
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { func RecordConsumeLog(userId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !common.LogConsumeEnabled {
if !config.LogConsumeEnabled {
return return
} }
log := &Log{ log := &Log{
UserId: userId, UserId: userId,
Username: GetUsernameById(userId), Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(), CreatedAt: common.GetTimestamp(),
Type: LogTypeConsume, Type: LogTypeConsume,
Content: content, Content: content,
PromptTokens: promptTokens, PromptTokens: promptTokens,
@@ -67,15 +59,14 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
TokenName: tokenName, TokenName: tokenName,
ModelName: modelName, ModelName: modelName,
Quota: quota, Quota: quota,
ChannelId: channelId,
} }
err := DB.Create(log).Error err := DB.Create(log).Error
if err != nil { if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error()) common.SysError("failed to record log: " + err.Error())
} }
} }
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB var tx *gorm.DB
if logType == LogTypeUnknown { if logType == LogTypeUnknown {
tx = DB tx = DB
@@ -97,9 +88,6 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
if endTimestamp != 0 { if endTimestamp != 0 {
tx = tx.Where("created_at <= ?", endTimestamp) tx = tx.Where("created_at <= ?", endTimestamp)
} }
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error err = tx.Order("id desc").Limit(num).Offset(startIdx).Find(&logs).Error
return logs, err return logs, err
} }
@@ -128,17 +116,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
} }
func SearchAllLogs(keyword string) (logs []*Log, err error) { func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error
return logs, err return logs, err
} }
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err return logs, err
} }
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) { func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (quota int) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)") tx := DB.Table("logs").Select("sum(quota)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -154,15 +142,12 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
if modelName != "" { if modelName != "" {
tx = tx.Where("model_name = ?", modelName) tx = tx.Where("model_name = ?", modelName)
} }
if channel != 0 {
tx = tx.Where("channel_id = ?", channel)
}
tx.Where("type = ?", LogTypeConsume).Scan(&quota) tx.Where("type = ?", LogTypeConsume).Scan(&quota)
return quota return quota
} }
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) { func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)") tx := DB.Table("logs").Select("sum(prompt_tokens) + sum(completion_tokens)")
if username != "" { if username != "" {
tx = tx.Where("username = ?", username) tx = tx.Where("username = ?", username)
} }
@@ -181,45 +166,3 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
tx.Where("type = ?", LogTypeConsume).Scan(&token) tx.Where("type = ?", LogTypeConsume).Scan(&token)
return token return token
} }
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
type LogStatistic struct {
Day string `gorm:"column:day"`
ModelName string `gorm:"column:model_name"`
RequestCount int `gorm:"column:request_count"`
Quota int `gorm:"column:quota"`
PromptTokens int `gorm:"column:prompt_tokens"`
CompletionTokens int `gorm:"column:completion_tokens"`
}
func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatistic, err error) {
groupSelect := "DATE_FORMAT(FROM_UNIXTIME(created_at), '%Y-%m-%d') as day"
if common.UsingPostgreSQL {
groupSelect = "TO_CHAR(date_trunc('day', to_timestamp(created_at)), 'YYYY-MM-DD') as day"
}
if common.UsingSQLite {
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
}
err = DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,
sum(prompt_tokens) as prompt_tokens,
sum(completion_tokens) as completion_tokens
FROM logs
WHERE type=2
AND user_id= ?
AND created_at BETWEEN ? AND ?
GROUP BY day, model_name
ORDER BY day, model_name
`, userId, start, end).Scan(&LogStatistics).Error
return LogStatistics, err
}

View File

@@ -1,27 +1,22 @@
package model package model
import ( import (
"fmt" "one-api/common"
"github.com/songquanpeng/one-api/common" "os"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
"os"
"strings"
"time"
) )
var DB *gorm.DB var DB *gorm.DB
func createRootAccountIfNeed() error { func createRootAccountIfNeed() error {
var user User var user User
//if user.Status != util.UserStatusEnabled { //if user.Status != common.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil { if err := DB.First(&user).Error; err != nil {
logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") common.SysLog("no user exists, create a root user for you: username is root, password is 123456")
hashedPassword, err := common.Password2Hash("123456") hashedPassword, err := common.Password2Hash("123456")
if err != nil { if err != nil {
return err return err
@@ -32,7 +27,7 @@ func createRootAccountIfNeed() error {
Role: common.RoleRootUser, Role: common.RoleRootUser,
Status: common.UserStatusEnabled, Status: common.UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: helper.GetUUID(), AccessToken: common.GetUUID(),
Quota: 100000000, Quota: 100000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -40,55 +35,41 @@ func createRootAccountIfNeed() error {
return nil return nil
} }
func chooseDB() (*gorm.DB, error) { func CountTable(tableName string) (num int64) {
if os.Getenv("SQL_DSN") != "" { DB.Table(tableName).Count(&num)
dsn := os.Getenv("SQL_DSN") return
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
logger.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
return gorm.Open(postgres.New(postgres.Config{
DSN: dsn,
PreferSimpleProtocol: true, // disables implicit prepared statement usage
}), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use MySQL
logger.SysLog("using MySQL as database")
return gorm.Open(mysql.Open(dsn), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
// Use SQLite
logger.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout)
return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} }
func InitDB() (err error) { func InitDB() (err error) {
db, err := chooseDB() var db *gorm.DB
if os.Getenv("POSTGRES_DSN") != "" {
// Use PostgreSQL
common.SysLog("using PostgreSQL as database")
common.UsingPostgreSQL = true
db, err = gorm.Open(postgres.Open(os.Getenv("POSTGRES_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else if os.Getenv("SQL_DSN") != "" {
// Use MySQL
common.SysLog("using MySQL as database")
db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
} else {
// Use SQLite
common.SysLog("SQL_DSN not set, using SQLite as database")
common.UsingSQLite = true
db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
PrepareStmt: true, // precompile SQL
})
}
common.SysLog("database connected")
if err == nil { if err == nil {
if config.DebugEnabled {
db = db.Debug()
}
DB = db DB = db
sqlDB, err := DB.DB() if !common.IsMasterNode {
if err != nil {
return err
}
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode {
return nil return nil
} }
logger.SysLog("database migration started") err := db.AutoMigrate(&Channel{})
err = db.AutoMigrate(&Channel{})
if err != nil { if err != nil {
return err return err
} }
@@ -116,11 +97,11 @@ func InitDB() (err error) {
if err != nil { if err != nil {
return err return err
} }
logger.SysLog("database migrated") common.SysLog("database migrated")
err = createRootAccountIfNeed() err = createRootAccountIfNeed()
return err return err
} else { } else {
logger.FatalLog(err) common.FatalLog(err)
} }
return err return err
} }

View File

@@ -1,9 +1,7 @@
package model package model
import ( import (
"github.com/songquanpeng/one-api/common" "one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -22,57 +20,62 @@ func AllOption() ([]*Option, error) {
} }
func InitOptionMap() { func InitOptionMap() {
config.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
config.OptionMap = make(map[string]string) common.OptionMap = make(map[string]string)
config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission)
config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission)
config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission)
config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission)
config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled)
config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled)
config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled)
config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled)
config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) common.OptionMap["DiscordOAuthEnabled"] = strconv.FormatBool(common.DiscordOAuthEnabled)
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled)
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) common.OptionMap["GoogleOAuthEnabled"] = strconv.FormatBool(common.GoogleOAuthEnabled)
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled)
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled)
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled)
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled)
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled)
config.OptionMap["SMTPServer"] = "" common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled)
config.OptionMap["SMTPFrom"] = "" common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled)
config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64)
config.OptionMap["SMTPAccount"] = "" common.OptionMap["SMTPServer"] = ""
config.OptionMap["SMTPToken"] = "" common.OptionMap["SMTPFrom"] = ""
config.OptionMap["Notice"] = "" common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort)
config.OptionMap["About"] = "" common.OptionMap["SMTPAccount"] = ""
config.OptionMap["HomePageContent"] = "" common.OptionMap["SMTPToken"] = ""
config.OptionMap["Footer"] = config.Footer common.OptionMap["Notice"] = ""
config.OptionMap["SystemName"] = config.SystemName common.OptionMap["About"] = ""
config.OptionMap["Logo"] = config.Logo common.OptionMap["HomePageContent"] = ""
config.OptionMap["ServerAddress"] = "" common.OptionMap["Footer"] = common.Footer
config.OptionMap["GitHubClientId"] = "" common.OptionMap["SystemName"] = common.SystemName
config.OptionMap["GitHubClientSecret"] = "" common.OptionMap["Logo"] = common.Logo
config.OptionMap["WeChatServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
config.OptionMap["WeChatServerToken"] = "" common.OptionMap["GitHubClientId"] = ""
config.OptionMap["WeChatAccountQRCodeImageURL"] = "" common.OptionMap["GitHubClientSecret"] = ""
config.OptionMap["TurnstileSiteKey"] = "" common.OptionMap["DiscordClientId"] = ""
config.OptionMap["TurnstileSecretKey"] = "" common.OptionMap["DiscordClientSecret"] = ""
config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) common.OptionMap["WeChatServerAddress"] = ""
config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) common.OptionMap["WeChatServerToken"] = ""
config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) common.OptionMap["WeChatAccountQRCodeImageURL"] = ""
config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) common.OptionMap["GoogleClientId"] = ""
config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) common.OptionMap["GoogleClientSecret"] = ""
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() common.OptionMap["TurnstileSiteKey"] = ""
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() common.OptionMap["TurnstileSecretKey"] = ""
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser)
config.OptionMap["TopUpLink"] = config.TopUpLink common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter)
config.OptionMap["ChatLink"] = config.ChatLink common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee)
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
config.OptionMap["Theme"] = config.Theme common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMapRWMutex.Unlock() common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
common.OptionMap["TopUpLink"] = common.TopUpLink
common.OptionMap["ChatLink"] = common.ChatLink
common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64)
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
@@ -81,7 +84,7 @@ func loadOptionsFromDatabase() {
for _, option := range options { for _, option := range options {
err := updateOptionMap(option.Key, option.Value) err := updateOptionMap(option.Key, option.Value)
if err != nil { if err != nil {
logger.SysError("failed to update option map: " + err.Error()) common.SysError("failed to update option map: " + err.Error())
} }
} }
} }
@@ -89,7 +92,7 @@ func loadOptionsFromDatabase() {
func SyncOptions(frequency int) { func SyncOptions(frequency int) {
for { for {
time.Sleep(time.Duration(frequency) * time.Second) time.Sleep(time.Duration(frequency) * time.Second)
logger.SysLog("syncing options from database") common.SysLog("syncing options from database")
loadOptionsFromDatabase() loadOptionsFromDatabase()
} }
} }
@@ -111,106 +114,121 @@ func UpdateOption(key string, value string) error {
} }
func updateOptionMap(key string, value string) (err error) { func updateOptionMap(key string, value string) (err error) {
config.OptionMapRWMutex.Lock() common.OptionMapRWMutex.Lock()
defer config.OptionMapRWMutex.Unlock() defer common.OptionMapRWMutex.Unlock()
config.OptionMap[key] = value common.OptionMap[key] = value
if strings.HasSuffix(key, "Permission") {
intValue, _ := strconv.Atoi(value)
switch key {
case "FileUploadPermission":
common.FileUploadPermission = intValue
case "FileDownloadPermission":
common.FileDownloadPermission = intValue
case "ImageUploadPermission":
common.ImageUploadPermission = intValue
case "ImageDownloadPermission":
common.ImageDownloadPermission = intValue
}
}
if strings.HasSuffix(key, "Enabled") { if strings.HasSuffix(key, "Enabled") {
boolValue := value == "true" boolValue := value == "true"
switch key { switch key {
case "PasswordRegisterEnabled": case "PasswordRegisterEnabled":
config.PasswordRegisterEnabled = boolValue common.PasswordRegisterEnabled = boolValue
case "PasswordLoginEnabled": case "PasswordLoginEnabled":
config.PasswordLoginEnabled = boolValue common.PasswordLoginEnabled = boolValue
case "EmailVerificationEnabled": case "EmailVerificationEnabled":
config.EmailVerificationEnabled = boolValue common.EmailVerificationEnabled = boolValue
case "GitHubOAuthEnabled": case "GitHubOAuthEnabled":
config.GitHubOAuthEnabled = boolValue common.GitHubOAuthEnabled = boolValue
case "DiscordOAuthEnabled":
common.DiscordOAuthEnabled = boolValue
case "WeChatAuthEnabled": case "WeChatAuthEnabled":
config.WeChatAuthEnabled = boolValue common.WeChatAuthEnabled = boolValue
case "GoogleOAuthEnabled":
common.GoogleOAuthEnabled = boolValue
case "TurnstileCheckEnabled": case "TurnstileCheckEnabled":
config.TurnstileCheckEnabled = boolValue common.TurnstileCheckEnabled = boolValue
case "RegisterEnabled": case "RegisterEnabled":
config.RegisterEnabled = boolValue common.RegisterEnabled = boolValue
case "EmailDomainRestrictionEnabled":
config.EmailDomainRestrictionEnabled = boolValue
case "AutomaticDisableChannelEnabled": case "AutomaticDisableChannelEnabled":
config.AutomaticDisableChannelEnabled = boolValue common.AutomaticDisableChannelEnabled = boolValue
case "AutomaticEnableChannelEnabled":
config.AutomaticEnableChannelEnabled = boolValue
case "ApproximateTokenEnabled": case "ApproximateTokenEnabled":
config.ApproximateTokenEnabled = boolValue common.ApproximateTokenEnabled = boolValue
case "LogConsumeEnabled": case "LogConsumeEnabled":
config.LogConsumeEnabled = boolValue common.LogConsumeEnabled = boolValue
case "DisplayInCurrencyEnabled": case "DisplayInCurrencyEnabled":
config.DisplayInCurrencyEnabled = boolValue common.DisplayInCurrencyEnabled = boolValue
case "DisplayTokenStatEnabled": case "DisplayTokenStatEnabled":
config.DisplayTokenStatEnabled = boolValue common.DisplayTokenStatEnabled = boolValue
} }
} }
switch key { switch key {
case "EmailDomainWhitelist":
config.EmailDomainWhitelist = strings.Split(value, ",")
case "SMTPServer": case "SMTPServer":
config.SMTPServer = value common.SMTPServer = value
case "SMTPPort": case "SMTPPort":
intValue, _ := strconv.Atoi(value) intValue, _ := strconv.Atoi(value)
config.SMTPPort = intValue common.SMTPPort = intValue
case "SMTPAccount": case "SMTPAccount":
config.SMTPAccount = value common.SMTPAccount = value
case "SMTPFrom": case "SMTPFrom":
config.SMTPFrom = value common.SMTPFrom = value
case "SMTPToken": case "SMTPToken":
config.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
config.ServerAddress = value common.ServerAddress = value
case "GitHubClientId": case "GitHubClientId":
config.GitHubClientId = value common.GitHubClientId = value
case "GitHubClientSecret": case "GitHubClientSecret":
config.GitHubClientSecret = value common.GitHubClientSecret = value
case "DiscordClientId":
common.DiscordClientId = value
case "DiscordClientSecret":
common.DiscordClientSecret = value
case "Footer": case "Footer":
config.Footer = value common.Footer = value
case "SystemName": case "SystemName":
config.SystemName = value common.SystemName = value
case "Logo": case "Logo":
config.Logo = value common.Logo = value
case "WeChatServerAddress": case "WeChatServerAddress":
config.WeChatServerAddress = value common.WeChatServerAddress = value
case "WeChatServerToken": case "WeChatServerToken":
config.WeChatServerToken = value common.WeChatServerToken = value
case "WeChatAccountQRCodeImageURL": case "WeChatAccountQRCodeImageURL":
config.WeChatAccountQRCodeImageURL = value common.WeChatAccountQRCodeImageURL = value
case "GoogleClientId":
common.GoogleClientId = value
case "GoogleClientSecret":
common.GoogleClientSecret = value
case "TurnstileSiteKey": case "TurnstileSiteKey":
config.TurnstileSiteKey = value common.TurnstileSiteKey = value
case "TurnstileSecretKey": case "TurnstileSecretKey":
config.TurnstileSecretKey = value common.TurnstileSecretKey = value
case "QuotaForNewUser": case "QuotaForNewUser":
config.QuotaForNewUser, _ = strconv.Atoi(value) common.QuotaForNewUser, _ = strconv.Atoi(value)
case "QuotaForInviter": case "QuotaForInviter":
config.QuotaForInviter, _ = strconv.Atoi(value) common.QuotaForInviter, _ = strconv.Atoi(value)
case "QuotaForInvitee": case "QuotaForInvitee":
config.QuotaForInvitee, _ = strconv.Atoi(value) common.QuotaForInvitee, _ = strconv.Atoi(value)
case "QuotaRemindThreshold": case "QuotaRemindThreshold":
config.QuotaRemindThreshold, _ = strconv.Atoi(value) common.QuotaRemindThreshold, _ = strconv.Atoi(value)
case "PreConsumedQuota": case "PreConsumedQuota":
config.PreConsumedQuota, _ = strconv.Atoi(value) common.PreConsumedQuota, _ = strconv.Atoi(value)
case "RetryTimes": case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value) common.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio": case "ModelRatio":
err = common.UpdateModelRatioByJSONString(value) err = common.UpdateModelRatioByJSONString(value)
case "GroupRatio": case "GroupRatio":
err = common.UpdateGroupRatioByJSONString(value) err = common.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
err = common.UpdateCompletionRatioByJSONString(value)
case "TopUpLink": case "TopUpLink":
config.TopUpLink = value common.TopUpLink = value
case "ChatLink": case "ChatLink":
config.ChatLink = value common.ChatLink = value
case "ChannelDisableThreshold": case "ChannelDisableThreshold":
config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64)
case "QuotaPerUnit": case "QuotaPerUnit":
config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "Theme":
config.Theme = value
} }
return err return err
} }

View File

@@ -3,8 +3,8 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -51,13 +51,15 @@ func Redeem(key string, userId int) (quota int, err error) {
} }
redemption := &Redemption{} redemption := &Redemption{}
keyCol := "`key`"
if common.UsingPostgreSQL {
keyCol = `"key"`
}
err = DB.Transaction(func(tx *gorm.DB) error { err = DB.Transaction(func(tx *gorm.DB) error {
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(keyCol+" = ?", key).First(redemption).Error whereItem := "`key` = ?"
if common.UsingPostgreSQL {
// Make cmd compatible with PostgreSQL
whereItem = "\"key\" = ?"
}
err := tx.Set("gorm:query_option", "FOR UPDATE").Where(whereItem, key).First(redemption).Error
if err != nil { if err != nil {
return errors.New("无效的兑换码") return errors.New("无效的兑换码")
} }
@@ -68,7 +70,7 @@ func Redeem(key string, userId int) (quota int, err error) {
if err != nil { if err != nil {
return err return err
} }
redemption.RedeemedTime = helper.GetTimestamp() redemption.RedeemedTime = common.GetTimestamp()
redemption.Status = common.RedemptionCodeStatusUsed redemption.Status = common.RedemptionCodeStatusUsed
err = tx.Save(redemption).Error err = tx.Save(redemption).Error
return err return err

View File

@@ -3,11 +3,8 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common"
) )
type Token struct { type Token struct {
@@ -41,43 +38,36 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("未提供令牌") return nil, errors.New("未提供令牌")
} }
token, err = CacheGetTokenByKey(key) token, err = CacheGetTokenByKey(key)
if err != nil { if err == nil {
logger.SysError("CacheGetTokenByKey failed: " + err.Error()) if token.Status != common.TokenStatusEnabled {
if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("该令牌状态不可用")
return nil, errors.New("无效的令牌")
} }
return nil, errors.New("令牌验证失败") if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
}
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("该令牌额度已用尽")
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
if token.Status != common.TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
token.Status = common.TokenStatusExpired token.Status = common.TokenStatusExpired
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
return nil, errors.New("该令牌已过期")
} }
return nil, errors.New("该令牌已过期") if !token.UnlimitedQuota && token.RemainQuota <= 0 {
}
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
token.Status = common.TokenStatusExhausted token.Status = common.TokenStatusExhausted
err := token.SelectUpdate() err := token.SelectUpdate()
if err != nil { if err != nil {
logger.SysError("failed to update token status" + err.Error()) common.SysError("failed to update token status" + err.Error())
} }
return nil, errors.New("该令牌额度已用尽")
} }
return nil, errors.New("该令牌额度已用尽") go func() {
token.AccessedTime = common.GetTimestamp()
err := token.SelectUpdate()
if err != nil {
common.SysError("failed to update token" + err.Error())
}
}()
return token, nil
} }
return token, nil return nil, errors.New("无效的令牌")
} }
func GetTokenByIds(id int, userId int) (*Token, error) { func GetTokenByIds(id int, userId int) (*Token, error) {
@@ -141,19 +131,10 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
return nil
}
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates( err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota), "remain_quota": gorm.Expr("remain_quota + ?", quota),
"used_quota": gorm.Expr("used_quota - ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota),
"accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
@@ -163,19 +144,10 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
return nil
}
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates( err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota), "remain_quota": gorm.Expr("remain_quota - ?", quota),
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"accessed_time": helper.GetTimestamp(),
}, },
).Error ).Error
return err return err
@@ -199,24 +171,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
if userQuota < quota { if userQuota < quota {
return errors.New("用户额度不足") return errors.New("用户额度不足")
} }
quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold
noMoreQuota := userQuota-quota <= 0 noMoreQuota := userQuota-quota <= 0
if quotaTooLow || noMoreQuota { if quotaTooLow || noMoreQuota {
go func() { go func() {
email, err := GetUserEmail(token.UserId) email, err := GetUserEmail(token.UserId)
if err != nil { if err != nil {
logger.SysError("failed to fetch user email: " + err.Error()) common.SysError("failed to fetch user email: " + err.Error())
} }
prompt := "您的额度即将用尽" prompt := "您的额度即将用尽"
if noMoreQuota { if noMoreQuota {
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {
logger.SysError("failed to send email" + err.Error()) common.SysError("failed to send email" + err.Error())
} }
} }
}() }()

View File

@@ -3,12 +3,10 @@ package model
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"strings" "strings"
"gorm.io/gorm"
) )
// User if you add sensitive fields, don't forget to clean them in setupLogin function. // User if you add sensitive fields, don't forget to clean them in setupLogin function.
@@ -18,11 +16,13 @@ type User struct {
Username string `json:"username" gorm:"unique;index" validate:"max=12"` Username string `json:"username" gorm:"unique;index" validate:"max=12"`
Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"` Password string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
DisplayName string `json:"display_name" gorm:"index" validate:"max=20"` DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
Role int `json:"role" gorm:"type:int;default:1"` // admin, util Role int `json:"role" gorm:"type:int;default:1"` // admin, common
Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled Status int `json:"status" gorm:"type:int;default:1"` // enabled, disabled
Email string `json:"email" gorm:"index" validate:"max=50"` Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"` GitHubId string `json:"github_id" gorm:"column:github_id;index"`
DiscordId string `json:"discord_id" gorm:"column:discord_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
GoogleId string `json:"google_id" gorm:"column:google_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"type:int;default:0"` Quota int `json:"quota" gorm:"type:int;default:0"`
@@ -45,11 +45,7 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
} }
func SearchUsers(keyword string) (users []*User, err error) { func SearchUsers(keyword string) (users []*User, err error) {
if !common.UsingPostgreSQL { err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
} else {
err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
}
return users, err return users, err
} }
@@ -92,24 +88,24 @@ func (user *User) Insert(inviterId int) error {
return err return err
} }
} }
user.Quota = config.QuotaForNewUser user.Quota = common.QuotaForNewUser
user.AccessToken = helper.GetUUID() user.AccessToken = common.GetUUID()
user.AffCode = helper.GetRandomString(4) user.AffCode = common.GetRandomString(4)
result := DB.Create(user) result := DB.Create(user)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
if config.QuotaForNewUser > 0 { if common.QuotaForNewUser > 0 {
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
} }
if inviterId != 0 { if inviterId != 0 {
if config.QuotaForInvitee > 0 { if common.QuotaForInvitee > 0 {
_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee) _ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee))) RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
} }
if config.QuotaForInviter > 0 { if common.QuotaForInviter > 0 {
_ = IncreaseUserQuota(inviterId, config.QuotaForInviter) _ = IncreaseUserQuota(inviterId, common.QuotaForInviter)
RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter))) RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
} }
} }
return nil return nil
@@ -144,15 +140,7 @@ func (user *User) ValidateAndFill() (err error) {
if user.Username == "" || password == "" { if user.Username == "" || password == "" {
return errors.New("用户名或密码为空") return errors.New("用户名或密码为空")
} }
err = DB.Where("username = ?", user.Username).First(user).Error DB.Where(User{Username: user.Username}).First(user)
if err != nil {
// we must make sure check username firstly
// consider this case: a malicious user set his username as other's email
err := DB.Where("email = ?", user.Username).First(user).Error
if err != nil {
return errors.New("用户名或密码错误,或用户已被封禁")
}
}
okay := common.ValidatePasswordAndHash(password, user.Password) okay := common.ValidatePasswordAndHash(password, user.Password)
if !okay || user.Status != common.UserStatusEnabled { if !okay || user.Status != common.UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁") return errors.New("用户名或密码错误,或用户已被封禁")
@@ -184,6 +172,14 @@ func (user *User) FillUserByGitHubId() error {
return nil return nil
} }
func (user *User) FillUserByDiscordId() error {
if user.DiscordId == "" {
return errors.New("Discord id 为空!")
}
DB.Where(User{DiscordId: user.DiscordId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error { func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" { if user.WeChatId == "" {
return errors.New("WeChat id 为空!") return errors.New("WeChat id 为空!")
@@ -192,6 +188,14 @@ func (user *User) FillUserByWeChatId() error {
return nil return nil
} }
func (user *User) FillUserByGoogleId() error {
if user.GoogleId == "" {
return errors.New("Google id 为空!")
}
DB.Where(User{GoogleId: user.GoogleId}).First(user)
return nil
}
func (user *User) FillUserByUsername() error { func (user *User) FillUserByUsername() error {
if user.Username == "" { if user.Username == "" {
return errors.New("username 为空!") return errors.New("username 为空!")
@@ -208,6 +212,14 @@ func IsWeChatIdAlreadyTaken(wechatId string) bool {
return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1 return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
} }
func IsDiscordIdAlreadyTaken(discordId string) bool {
return DB.Where("discord_id = ?", discordId).Find(&User{}).RowsAffected == 1
}
func IsGoogleIdAlreadyTaken(googleId string) bool {
return DB.Where("google_id = ?", googleId).Find(&User{}).RowsAffected == 1
}
func IsGitHubIdAlreadyTaken(githubId string) bool { func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
} }
@@ -235,22 +247,23 @@ func IsAdmin(userId int) bool {
var user User var user User
err := DB.Where("id = ?", userId).Select("role").Find(&user).Error err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
if err != nil { if err != nil {
logger.SysError("no such user " + err.Error()) common.SysError("no such user " + err.Error())
return false return false
} }
return user.Role >= common.RoleAdminUser return user.Role >= common.RoleAdminUser
} }
func IsUserEnabled(userId int) (bool, error) { func IsUserEnabled(userId int) bool {
if userId == 0 { if userId == 0 {
return false, errors.New("user id is empty") return false
} }
var user User var user User
err := DB.Where("id = ?", userId).Select("status").Find(&user).Error err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
if err != nil { if err != nil {
return false, err common.SysError("no such user " + err.Error())
return false
} }
return user.Status == common.UserStatusEnabled, nil return user.Status == common.UserStatusEnabled
} }
func ValidateAccessToken(token string) (user *User) { func ValidateAccessToken(token string) (user *User) {
@@ -281,12 +294,13 @@ func GetUserEmail(id int) (email string, err error) {
} }
func GetUserGroup(id int) (group string, err error) { func GetUserGroup(id int) (group string, err error) {
groupCol := "`group`" selectItem := "`group`"
if common.UsingPostgreSQL { if common.UsingPostgreSQL {
groupCol = `"group"` selectItem = "\"group\""
} }
err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error err = DB.Model(&User{}).Where("id = ?", id).Select(selectItem).Find(&group).Error
return group, err return group, err
} }
@@ -294,14 +308,6 @@ func IncreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, quota)
return nil
}
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err return err
} }
@@ -310,14 +316,6 @@ func DecreaseUserQuota(id int, quota int) (err error) {
if quota < 0 { if quota < 0 {
return errors.New("quota 不能为负数!") return errors.New("quota 不能为负数!")
} }
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
return nil
}
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err return err
} }
@@ -328,41 +326,14 @@ func GetRootUserEmail() (email string) {
} }
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
return
}
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates( err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{ map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota),
"request_count": gorm.Expr("request_count + ?", count), "request_count": gorm.Expr("request_count + ?", 1),
}, },
).Error ).Error
if err != nil { if err != nil {
logger.SysError("failed to update user used quota and request count: " + err.Error()) common.SysError("failed to update user used quota and request count: " + err.Error())
}
}
func updateUserUsedQuota(id int, quota int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
},
).Error
if err != nil {
logger.SysError("failed to update user used quota: " + err.Error())
}
}
func updateUserRequestCount(id int, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
if err != nil {
logger.SysError("failed to update user request count: " + err.Error())
} }
} }

View File

@@ -1,78 +0,0 @@
package model
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"sync"
"time"
)
const (
BatchUpdateTypeUserQuota = iota
BatchUpdateTypeTokenQuota
BatchUpdateTypeUsedQuota
BatchUpdateTypeChannelUsedQuota
BatchUpdateTypeRequestCount
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
func InitBatchUpdater() {
go func() {
for {
time.Sleep(time.Duration(config.BatchUpdateInterval) * time.Second)
batchUpdate()
}
}()
}
func addNewRecord(type_ int, id int, value int) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
batchUpdateStores[type_][id] = value
} else {
batchUpdateStores[type_][id] += value
}
}
func batchUpdate() {
logger.SysLog("batch update started")
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
switch i {
case BatchUpdateTypeUserQuota:
err := increaseUserQuota(key, value)
if err != nil {
logger.SysError("failed to batch update user quota: " + err.Error())
}
case BatchUpdateTypeTokenQuota:
err := increaseTokenQuota(key, value)
if err != nil {
logger.SysError("failed to batch update token quota: " + err.Error())
}
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}
}
}
logger.SysLog("batch update finished")
}

View File

@@ -1,9 +0,0 @@
[//]: # (请按照以下格式关联 issue)
[//]: # (请在提交 PR 前确认所提交的功能可用,附上截图即可,这将有助于项目维护者 review & merge 该 PR谢谢)
[//]: # (项目维护者一般仅在周末处理 PR因此如若未能及时回复希望能理解)
[//]: # (开发者交流群910657413)
[//]: # (请在提交 PR 之前删除上面的注释)
close #issue_number
我已确认该 PR 已自测通过,相关截图如下:

View File

@@ -1,8 +0,0 @@
package ai360
var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}

View File

@@ -1,60 +0,0 @@
package aiproxy
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "aiproxy"
}

View File

@@ -1,9 +0,0 @@
package aiproxy
import "github.com/songquanpeng/one-api/relay/channel/openai"
var ModelList = []string{""}
func init() {
ModelList = openai.ModelList
}

View File

@@ -1,197 +0,0 @@
package aiproxy
import (
"bufio"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
"strings"
)
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
func ConvertRequest(request model.GeneralOpenAIRequest) *LibraryRequest {
query := ""
if len(request.Messages) != 0 {
query = request.Messages[len(request.Messages)-1].StringContent()
}
return &LibraryRequest{
Model: request.Model,
Stream: request.Stream,
Query: query,
}
}
func aiProxyDocuments2Markdown(documents []LibraryDocument) string {
if len(documents) == 0 {
return ""
}
content := "\n\n参考文档\n"
for i, document := range documents {
content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
}
return content
}
func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextResponse {
content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: content,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: helper.GetUUID(),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
}
return &fullTextResponse
}
func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
Id: helper.GetUUID(),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
Id: helper.GetUUID(),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
var documents []LibraryDocument
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var AIProxyLibraryResponse LibraryStreamResponse
err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if len(AIProxyLibraryResponse.Documents) != 0 {
documents = AIProxyLibraryResponse.Documents
}
response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
response := documentsAIProxyLibrary(documents)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var AIProxyLibraryResponse LibraryResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if AIProxyLibraryResponse.ErrCode != 0 {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: AIProxyLibraryResponse.Message,
Type: strconv.Itoa(AIProxyLibraryResponse.ErrCode),
Code: AIProxyLibraryResponse.ErrCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &fullTextResponse.Usage
}

View File

@@ -1,32 +0,0 @@
package aiproxy
type LibraryRequest struct {
Model string `json:"model"`
Query string `json:"query"`
LibraryId string `json:"libraryId"`
Stream bool `json:"stream"`
}
type LibraryError struct {
ErrCode int `json:"errCode"`
Message string `json:"message"`
}
type LibraryDocument struct {
Title string `json:"title"`
URL string `json:"url"`
}
type LibraryResponse struct {
Success bool `json:"success"`
Answer string `json:"answer"`
Documents []LibraryDocument `json:"documents"`
LibraryError
}
type LibraryStreamResponse struct {
Content string `json:"content"`
Finish bool `json:"finish"`
Model string `json:"model"`
Documents []LibraryDocument `json:"documents"`
}

View File

@@ -1,83 +0,0 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}

View File

@@ -1,6 +0,0 @@
package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
}

View File

@@ -1,257 +0,0 @@
package ali
import (
"bufio"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, Message{
Content: message.StringContent(),
Role: strings.ToLower(message.Role),
})
}
enableSearch := false
aliModel := request.Model
if strings.HasSuffix(aliModel, EnableSearchModelSuffix) {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
return &ChatRequest{
Model: aliModel,
Input: Input{
Messages: messages,
},
Parameters: Parameters{
EnableSearch: enableSearch,
IncrementalOutput: request.Stream,
Seed: uint64(request.Seed),
},
}
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "text-embedding-v1",
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: "text-embedding-v1",
Usage: model.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "qwen",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
//lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse ChatResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
fullTextResponse.Model = "qwen"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -1,71 +0,0 @@
package ali
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}

View File

@@ -1,65 +0,0 @@
package anthropic
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/v1/complete", meta.BaseURL), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Header.Set("anthropic-version", anthropicVersion)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "authropic"
}

View File

@@ -1,5 +0,0 @@
package anthropic
var ModelList = []string{
"claude-instant-1", "claude-2", "claude-2.0", "claude-2.1",
}

View File

@@ -1,29 +0,0 @@
package anthropic
type Metadata struct {
UserId string `json:"user_id"`
}
type Request struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
MaxTokensToSample int `json:"max_tokens_to_sample"`
StopSequences []string `json:"stop_sequences,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
//Metadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type Error struct {
Type string `json:"type"`
Message string `json:"message"`
}
type Response struct {
Completion string `json:"completion"`
StopReason string `json:"stop_reason"`
Model string `json:"model"`
Error Error `json:"error"`
}

View File

@@ -1,93 +0,0 @@
package baidu
import (
"errors"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
var fullRequestURL string
switch meta.ActualModelName {
case "ERNIE-Bot-4":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"
case "ERNIE-Bot-8K":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"
case "ERNIE-Bot":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
case "ERNIE-Speed":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed"
case "ERNIE-Bot-turbo":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant"
case "BLOOMZ-7B":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1"
case "Embedding-V1":
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1"
}
var accessToken string
var err error
if accessToken, err = GetAccessToken(meta.APIKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "baidu"
}

View File

@@ -1,10 +0,0 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
}

View File

@@ -1,321 +0,0 @@
package baidu
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
"sync"
"time"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
type TokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRequest struct {
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
var baiduTokenStore sync.Map
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
for _, message := range request.Messages {
if message.Role == "system" {
messages = append(messages, Message{
Role: "user",
Content: message.StringContent(),
})
messages = append(messages, Message{
Role: "assistant",
Content: "Okay",
})
} else {
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &ChatRequest{
Messages: messages,
Stream: request.Stream,
}
}
func responseBaidu2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []openai.TextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *ChatStreamResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = baiduResponse.Result
if baiduResponse.IsEnd {
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
data = data[6:]
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var baiduResponse ChatStreamResponse
err := json.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
fullTextResponse.Model = "ernie-bot"
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var baiduResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if baiduResponse.ErrorMsg != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: baiduResponse.ErrorMsg,
Type: "baidu_error",
Param: "",
Code: baiduResponse.ErrorCode,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func GetAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken AccessToken
if accessToken, ok = val.(AccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("GetAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
parts[0], parts[1]), nil)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := util.ImpatientHTTPClient.Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var accessToken AccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

View File

@@ -1,50 +0,0 @@
package baidu
import (
"github.com/songquanpeng/one-api/relay/model"
"time"
)
type ChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage model.Usage `json:"usage"`
Error
}
type ChatStreamResponse struct {
ChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type EmbeddingRequest struct {
Input []string `json:"input"`
}
type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type EmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []EmbeddingData `json:"data"`
Usage model.Usage `json:"usage"`
Error
}
type AccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}

View File

@@ -1,51 +0,0 @@
package channel
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
req.Header.Set("Accept", "text/event-stream")
}
}
func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, req, meta)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := DoRequest(c, req)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
resp, err := util.HTTPClient.Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}

View File

@@ -1,66 +0,0 @@
package gemini
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
channelhelper "github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
}
return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
usage = openai.ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "google gemini"
}

View File

@@ -1,6 +0,0 @@
package gemini
var ModelList = []string{
"gemini-pro",
"gemini-pro-vision",
}

View File

@@ -1,303 +0,0 @@
package gemini
import (
"bufio"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
const (
VisionMaxImageNum = 16
)
// Setting safety to the lowest possible values since Gemini is already powerless enough
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
geminiRequest := ChatRequest{
Contents: make([]ChatContent, 0, len(textRequest.Messages)),
SafetySettings: []ChatSafetySettings{
{
Category: "HARM_CATEGORY_HARASSMENT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_HATE_SPEECH",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_SEXUALLY_EXPLICIT",
Threshold: config.GeminiSafetySetting,
},
{
Category: "HARM_CATEGORY_DANGEROUS_CONTENT",
Threshold: config.GeminiSafetySetting,
},
},
GenerationConfig: ChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
},
}
if textRequest.Functions != nil {
geminiRequest.Tools = []ChatTools{
{
FunctionDeclarations: textRequest.Functions,
},
}
}
shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
content := ChatContent{
Role: message.Role,
Parts: []Part{
{
Text: message.StringContent(),
},
},
}
openaiContent := message.ParseContent()
var parts []Part
imageNum := 0
for _, part := range openaiContent {
if part.Type == model.ContentTypeText {
parts = append(parts, Part{
Text: part.Text,
})
} else if part.Type == model.ContentTypeImageURL {
imageNum += 1
if imageNum > VisionMaxImageNum {
continue
}
mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
Data: data,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
// Converting system prompt to prompt from user for the same reason
if content.Role == "system" {
content.Role = "user"
shouldAddDummyModelMessage = true
}
geminiRequest.Contents = append(geminiRequest.Contents, content)
// If a system message is the last message, we need to add a dummy model message to make gemini happy
if shouldAddDummyModelMessage {
geminiRequest.Contents = append(geminiRequest.Contents, ChatContent{
Role: "model",
Parts: []Part{
{
Text: "Okay",
},
},
})
shouldAddDummyModelMessage = false
}
}
return &geminiRequest
}
type ChatResponse struct {
Candidates []ChatCandidate `json:"candidates"`
PromptFeedback ChatPromptFeedback `json:"promptFeedback"`
}
func (g *ChatResponse) GetResponseText() string {
if g == nil {
return ""
}
if len(g.Candidates) > 0 && len(g.Candidates[0].Content.Parts) > 0 {
return g.Candidates[0].Content.Parts[0].Text
}
return ""
}
type ChatCandidate struct {
Content ChatContent `json:"content"`
FinishReason string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
type ChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type ChatPromptFeedback struct {
SafetyRatings []ChatSafetyRating `json:"safetyRatings"`
}
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := openai.TextResponseChoice{
Index: i,
Message: model.Message{
Role: "assistant",
Content: "",
},
FinishReason: constant.StopFinishReason,
}
if len(candidate.Content.Parts) > 0 {
choice.Message.Content = candidate.Content.Parts[0].Text
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = geminiResponse.GetResponseText()
choice.FinishReason = &constant.StopFinishReason
var response openai.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "gemini"
response.Choices = []openai.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
dataChan := make(chan string)
stopChan := make(chan bool)
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
go func() {
for scanner.Scan() {
data := scanner.Text()
data = strings.TrimSpace(data)
if !strings.HasPrefix(data, "\"text\": \"") {
continue
}
data = strings.TrimPrefix(data, "\"text\": \"")
data = strings.TrimSuffix(data, "\"")
dataChan <- data
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
// this is used to prevent annoying \ related format bug
data = fmt.Sprintf("{\"content\": \"%s\"}", data)
type dummyStruct struct {
Content string `json:"content"`
}
var dummy dummyStruct
err := json.Unmarshal([]byte(data), &dummy)
responseText += dummy.Content
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = dummy.Content
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "gemini-pro",
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
return nil, responseText
}
func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
var geminiResponse ChatResponse
err = json.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if len(geminiResponse.Candidates) == 0 {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: "No candidates returned",
Type: "server_error",
Param: "",
Code: 500,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
fullTextResponse.Model = modelName
completionTokens := openai.CountTokenText(geminiResponse.GetResponseText(), modelName)
usage := model.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &usage
}

View File

@@ -1,41 +0,0 @@
package gemini
type ChatRequest struct {
Contents []ChatContent `json:"contents"`
SafetySettings []ChatSafetySettings `json:"safety_settings,omitempty"`
GenerationConfig ChatGenerationConfig `json:"generation_config,omitempty"`
Tools []ChatTools `json:"tools,omitempty"`
}
type InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
}
type ChatContent struct {
Role string `json:"role,omitempty"`
Parts []Part `json:"parts"`
}
type ChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type ChatTools struct {
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type ChatGenerationConfig struct {
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
}

View File

@@ -1,20 +0,0 @@
package channel
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *util.RelayMeta)
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}

Some files were not shown because too many files have changed in this diff Show More