mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	Compare commits
	
		
			47 Commits
		
	
	
		
			v0.5.3-alp
			...
			v0.5.5-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					420c375140 | ||
| 
						 | 
					01863d3e44 | ||
| 
						 | 
					d0a0e871e1 | ||
| 
						 | 
					bd6fe1e93c | ||
| 
						 | 
					c55bb67818 | ||
| 
						 | 
					0f949c3782 | ||
| 
						 | 
					a721a5b6f9 | ||
| 
						 | 
					276163affd | ||
| 
						 | 
					621eb91b46 | ||
| 
						 | 
					7e575abb95 | ||
| 
						 | 
					9db93316c4 | ||
| 
						 | 
					c3dc315e75 | ||
| 
						 | 
					04acdb1ccb | ||
| 
						 | 
					f0d5e102a3 | ||
| 
						 | 
					abbf2fded0 | ||
| 
						 | 
					ef2c5abb5b | ||
| 
						 | 
					56b5007379 | ||
| 
						 | 
					d09d317459 | ||
| 
						 | 
					1c4409ae80 | ||
| 
						 | 
					5ee24e8acf | ||
| 
						 | 
					4f2f911e4d | ||
| 
						 | 
					fdb2cccf65 | ||
| 
						 | 
					a3e267df7e | ||
| 
						 | 
					ac7c0f3a76 | ||
| 
						 | 
					efeb9a16ce | ||
| 
						 | 
					05e4f2b439 | ||
| 
						 | 
					7e058bfb9b | ||
| 
						 | 
					dfaa0183b7 | ||
| 
						 | 
					1b56becfaa | ||
| 
						 | 
					23b1c63538 | ||
| 
						 | 
					49d1a63402 | ||
| 
						 | 
					2a7b82650c | ||
| 
						 | 
					8ea7b9aae2 | ||
| 
						 | 
					5136b12612 | ||
| 
						 | 
					80a49e01a3 | ||
| 
						 | 
					8fb082ba3b | ||
| 
						 | 
					86c2627c24 | ||
| 
						 | 
					90b4cac7f3 | ||
| 
						 | 
					e4bacc45d6 | ||
| 
						 | 
					da1d81998f | ||
| 
						 | 
					cac61b9f66 | ||
| 
						 | 
					3da12e99d9 | ||
| 
						 | 
					4ef5e2020c | ||
| 
						 | 
					af20063a8d | ||
| 
						 | 
					ca512f6a38 | ||
| 
						 | 
					0e9ff8825e | ||
| 
						 | 
					e0b4f96b5b | 
@@ -1,9 +1,10 @@
 | 
			
		||||
FROM node:16 as builder
 | 
			
		||||
 | 
			
		||||
WORKDIR /build
 | 
			
		||||
COPY web/package.json .
 | 
			
		||||
RUN npm install
 | 
			
		||||
COPY ./web .
 | 
			
		||||
COPY ./VERSION .
 | 
			
		||||
RUN npm install
 | 
			
		||||
RUN DISABLE_ESLINT_PLUGIN='true' REACT_APP_VERSION=$(cat VERSION) npm run build
 | 
			
		||||
 | 
			
		||||
FROM golang AS builder2
 | 
			
		||||
@@ -13,9 +14,10 @@ ENV GO111MODULE=on \
 | 
			
		||||
    GOOS=linux
 | 
			
		||||
 | 
			
		||||
WORKDIR /build
 | 
			
		||||
ADD go.mod go.sum ./
 | 
			
		||||
RUN go mod download
 | 
			
		||||
COPY . .
 | 
			
		||||
COPY --from=builder /build/build ./web/build
 | 
			
		||||
RUN go mod download
 | 
			
		||||
RUN go build -ldflags "-s -w -X 'one-api/common.Version=$(cat VERSION)' -extldflags '-static'" -o one-api
 | 
			
		||||
 | 
			
		||||
FROM alpine
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
<p align="right">
 | 
			
		||||
    <a href="./README.md">中文</a> | <strong>English</strong>
 | 
			
		||||
    <a href="./README.md">中文</a> | <strong>English</strong> | <a href="./README.ja.md">日本語</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
<p align="center">
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										298
									
								
								README.ja.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										298
									
								
								README.ja.md
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,298 @@
 | 
			
		||||
<p align="right">
 | 
			
		||||
    <a href="./README.md">中文</a> | <a href="./README.en.md">English</a> | <strong>日本語</strong>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
<p align="center">
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api"><img src="https://raw.githubusercontent.com/songquanpeng/one-api/main/web/public/logo.png" width="150" height="150" alt="one-api logo"></a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
<div align="center">
 | 
			
		||||
 | 
			
		||||
# One API
 | 
			
		||||
 | 
			
		||||
_✨ 標準的な OpenAI API フォーマットを通じてすべての LLM にアクセスでき、導入と利用が容易です ✨_
 | 
			
		||||
 | 
			
		||||
</div>
 | 
			
		||||
 | 
			
		||||
<p align="center">
 | 
			
		||||
  <a href="https://raw.githubusercontent.com/songquanpeng/one-api/main/LICENSE">
 | 
			
		||||
    <img src="https://img.shields.io/github/license/songquanpeng/one-api?color=brightgreen" alt="license">
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api/releases/latest">
 | 
			
		||||
    <img src="https://img.shields.io/github/v/release/songquanpeng/one-api?color=brightgreen&include_prereleases" alt="release">
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://hub.docker.com/repository/docker/justsong/one-api">
 | 
			
		||||
    <img src="https://img.shields.io/docker/pulls/justsong/one-api?color=brightgreen" alt="docker pull">
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api/releases/latest">
 | 
			
		||||
    <img src="https://img.shields.io/github/downloads/songquanpeng/one-api/total?color=brightgreen&include_prereleases" alt="release">
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://goreportcard.com/report/github.com/songquanpeng/one-api">
 | 
			
		||||
    <img src="https://goreportcard.com/badge/github.com/songquanpeng/one-api" alt="GoReportCard">
 | 
			
		||||
  </a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
<p align="center">
 | 
			
		||||
  <a href="#deployment">デプロイチュートリアル</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="#usage">使用方法</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://github.com/songquanpeng/one-api/issues">フィードバック</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="#screenshots">スクリーンショット</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://openai.justsong.cn/">ライブデモ</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="#faq">FAQ</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="#related-projects">関連プロジェクト</a>
 | 
			
		||||
  ·
 | 
			
		||||
  <a href="https://iamazing.cn/page/reward">寄付</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **警告**: この README は ChatGPT によって翻訳されています。翻訳ミスを発見した場合は遠慮なく PR を投稿してください。
 | 
			
		||||
 | 
			
		||||
> **警告**: 英語版の Docker イメージは `justsong/one-api-en` です。
 | 
			
		||||
 | 
			
		||||
> **注**: Docker からプルされた最新のイメージは、`alpha` リリースかもしれません。安定性が必要な場合は、手動でバージョンを指定してください。
 | 
			
		||||
 | 
			
		||||
## 特徴
 | 
			
		||||
1. 複数の大型モデルをサポート:
 | 
			
		||||
   + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
 | 
			
		||||
   + [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
 | 
			
		||||
   + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google)
 | 
			
		||||
   + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
 | 
			
		||||
   + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
 | 
			
		||||
   + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
 | 
			
		||||
2. **ロードバランシング**による複数チャンネルへのアクセスをサポート。
 | 
			
		||||
3. ストリーム伝送によるタイプライター的効果を可能にする**ストリームモード**に対応。
 | 
			
		||||
4. **マルチマシンデプロイ**に対応。[詳細はこちら](#multi-machine-deployment)を参照。
 | 
			
		||||
5. トークンの有効期限や使用回数を設定できる**トークン管理**に対応しています。
 | 
			
		||||
6. **バウチャー管理**に対応しており、バウチャーの一括生成やエクスポートが可能です。バウチャーは口座残高の補充に利用できます。
 | 
			
		||||
7. **チャンネル管理**に対応し、チャンネルの一括作成が可能。
 | 
			
		||||
8. グループごとに異なるレートを設定するための**ユーザーグループ**と**チャンネルグループ**をサポートしています。
 | 
			
		||||
9. チャンネル**モデルリスト設定**に対応。
 | 
			
		||||
10. **クォータ詳細チェック**をサポート。
 | 
			
		||||
11. **ユーザー招待報酬**をサポートします。
 | 
			
		||||
12. 米ドルでの残高表示が可能。
 | 
			
		||||
13. 新規ユーザー向けのお知らせ公開、リチャージリンク設定、初期残高設定に対応。
 | 
			
		||||
14. 豊富な**カスタマイズ**オプションを提供します:
 | 
			
		||||
    1. システム名、ロゴ、フッターのカスタマイズが可能。
 | 
			
		||||
    2. HTML と Markdown コードを使用したホームページとアバウトページのカスタマイズ、または iframe を介したスタンドアロンウェブページの埋め込みをサポートしています。
 | 
			
		||||
15. システム・アクセストークンによる管理 API アクセスをサポートする。
 | 
			
		||||
16. Cloudflare Turnstile によるユーザー認証に対応。
 | 
			
		||||
17. ユーザー管理と複数のユーザーログイン/登録方法をサポート:
 | 
			
		||||
    + 電子メールによるログイン/登録とパスワードリセット。
 | 
			
		||||
    + [GitHub OAuth](https://github.com/settings/applications/new)。
 | 
			
		||||
    + WeChat 公式アカウントの認証([WeChat Server](https://github.com/songquanpeng/wechat-server)の追加導入が必要)。
 | 
			
		||||
18. 他の主要なモデル API が利用可能になった場合、即座にサポートし、カプセル化する。
 | 
			
		||||
 | 
			
		||||
## デプロイメント
 | 
			
		||||
### Docker デプロイメント
 | 
			
		||||
デプロイコマンド: `docker run --name one-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai -v /home/ubuntu/data/one-api:/data justsong/one-api-en`。
 | 
			
		||||
 | 
			
		||||
コマンドを更新する: `docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrr/watchtower -cR`。
 | 
			
		||||
 | 
			
		||||
`-p 3000:3000` の最初の `3000` はホストのポートで、必要に応じて変更できます。
 | 
			
		||||
 | 
			
		||||
データはホストの `/home/ubuntu/data/one-api` ディレクトリに保存される。このディレクトリが存在し、書き込み権限があることを確認する、もしくは適切なディレクトリに変更してください。
 | 
			
		||||
 | 
			
		||||
Nginxリファレンス設定:
 | 
			
		||||
```
 | 
			
		||||
server{
 | 
			
		||||
   server_name openai.justsong.cn;  # ドメイン名は適宜変更
 | 
			
		||||
 | 
			
		||||
   location / {
 | 
			
		||||
          client_max_body_size  64m;
 | 
			
		||||
          proxy_http_version 1.1;
 | 
			
		||||
          proxy_pass http://localhost:3000;  # それに応じてポートを変更
 | 
			
		||||
          proxy_set_header Host $host;
 | 
			
		||||
          proxy_set_header X-Forwarded-For $remote_addr;
 | 
			
		||||
          proxy_cache_bypass $http_upgrade;
 | 
			
		||||
          proxy_set_header Accept-Encoding gzip;
 | 
			
		||||
          proxy_read_timeout 300s;  # GPT-4 はより長いタイムアウトが必要
 | 
			
		||||
   }
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
次に、Let's Encrypt certbot を使って HTTPS を設定します:
 | 
			
		||||
```bash
 | 
			
		||||
# Ubuntu に certbot をインストール:
 | 
			
		||||
sudo snap install --classic certbot
 | 
			
		||||
sudo ln -s /snap/bin/certbot /usr/bin/certbot
 | 
			
		||||
# 証明書の生成と Nginx 設定の変更
 | 
			
		||||
sudo certbot --nginx
 | 
			
		||||
# プロンプトに従う
 | 
			
		||||
# Nginx を再起動
 | 
			
		||||
sudo service nginx restart
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
初期アカウントのユーザー名は `root` で、パスワードは `123456` です。
 | 
			
		||||
 | 
			
		||||
### マニュアルデプロイ
 | 
			
		||||
1. [GitHub Releases](https://github.com/songquanpeng/one-api/releases/latest) から実行ファイルをダウンロードする、もしくはソースからコンパイルする:
 | 
			
		||||
   ```shell
 | 
			
		||||
   git clone https://github.com/songquanpeng/one-api.git
 | 
			
		||||
 | 
			
		||||
   # フロントエンドのビルド
 | 
			
		||||
   cd one-api/web
 | 
			
		||||
   npm install
 | 
			
		||||
   npm run build
 | 
			
		||||
 | 
			
		||||
   # バックエンドのビルド
 | 
			
		||||
   cd ..
 | 
			
		||||
   go mod download
 | 
			
		||||
   go build -ldflags "-s -w" -o one-api
 | 
			
		||||
   ```
 | 
			
		||||
2. 実行:
 | 
			
		||||
   ```shell
 | 
			
		||||
   chmod u+x one-api
 | 
			
		||||
   ./one-api --port 3000 --log-dir ./logs
 | 
			
		||||
   ```
 | 
			
		||||
3. [http://localhost:3000/](http://localhost:3000/) にアクセスし、ログインする。初期アカウントのユーザー名は `root`、パスワードは `123456` である。
 | 
			
		||||
 | 
			
		||||
より詳細なデプロイのチュートリアルについては、[このページ](https://iamazing.cn/page/how-to-deploy-a-website) を参照してください。
 | 
			
		||||
 | 
			
		||||
### マルチマシンデプロイ
 | 
			
		||||
1. すべてのサーバに同じ `SESSION_SECRET` を設定する。
 | 
			
		||||
2. `SQL_DSN` を設定し、SQLite の代わりに MySQL を使用する。すべてのサーバは同じデータベースに接続する。
 | 
			
		||||
3. マスターノード以外のノードの `NODE_TYPE` を `slave` に設定する。
 | 
			
		||||
4. データベースから定期的に設定を同期するサーバーには `SYNC_FREQUENCY` を設定する。
 | 
			
		||||
5. マスター以外のノードでは、オプションで `FRONTEND_BASE_URL` を設定して、ページ要求をマスターサーバーにリダイレクトすることができます。
 | 
			
		||||
6. マスター以外のノードには Redis を個別にインストールし、`REDIS_CONN_STRING` を設定して、キャッシュの有効期限が切れていないときにデータベースにゼロレイテンシーでアクセスできるようにする。
 | 
			
		||||
7. メインサーバーでもデータベースへのアクセスが高レイテンシになる場合は、Redis を有効にし、`SYNC_FREQUENCY` を設定してデータベースから定期的に設定を同期する必要がある。
 | 
			
		||||
 | 
			
		||||
Please refer to the [environment variables](#environment-variables) section for details on using environment variables.
 | 
			
		||||
 | 
			
		||||
### コントロールパネル(例: Baota)への展開
 | 
			
		||||
詳しい手順は [#175](https://github.com/songquanpeng/one-api/issues/175) を参照してください。
 | 
			
		||||
 | 
			
		||||
配置後に空白のページが表示される場合は、[#97](https://github.com/songquanpeng/one-api/issues/97) を参照してください。
 | 
			
		||||
 | 
			
		||||
### サードパーティプラットフォームへのデプロイ
 | 
			
		||||
<details>
 | 
			
		||||
<summary><strong>Sealos へのデプロイ</strong></summary>
 | 
			
		||||
<div>
 | 
			
		||||
 | 
			
		||||
> Sealos は、高い同時実行性、ダイナミックなスケーリング、数百万人のユーザーに対する安定した運用をサポートしています。
 | 
			
		||||
 | 
			
		||||
> 下のボタンをクリックすると、ワンクリックで展開できます。👇
 | 
			
		||||
 | 
			
		||||
[](https://cloud.sealos.io/?openapp=system-fastdeploy?templateName=one-api)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
</div>
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
<details>
 | 
			
		||||
<summary><strong>Zeabur へのデプロイ</strong></summary>
 | 
			
		||||
<div>
 | 
			
		||||
 | 
			
		||||
> Zeabur のサーバーは海外にあるため、ネットワークの問題は自動的に解決されます。
 | 
			
		||||
 | 
			
		||||
1. まず、コードをフォークする。
 | 
			
		||||
2. [Zeabur](https://zeabur.com?referralCode=songquanpeng) にアクセスしてログインし、コンソールに入る。
 | 
			
		||||
3. 新しいプロジェクトを作成します。Service -> Add ServiceでMarketplace を選択し、MySQL を選択する。接続パラメータ(ユーザー名、パスワード、アドレス、ポート)をメモします。
 | 
			
		||||
4. 接続パラメータをコピーし、```create database `one-api` ``` を実行してデータベースを作成する。
 | 
			
		||||
5. その後、Service -> Add Service で Git を選択し(最初の使用には認証が必要です)、フォークしたリポジトリを選択します。
 | 
			
		||||
6. 自動デプロイが開始されますが、一旦キャンセルしてください。Variable タブで `PORT` に `3000` を追加し、`SQL_DSN` に `<username>:<password>@tcp(<addr>:<port>)/one-api` を追加します。変更を保存する。SQL_DSN` が設定されていないと、データが永続化されず、再デプロイ後にデータが失われるので注意すること。
 | 
			
		||||
7. 再デプロイを選択します。
 | 
			
		||||
8. Domains タブで、"my-one-api" のような適切なドメイン名の接頭辞を選択する。最終的なドメイン名は "my-one-api.zeabur.app" となります。独自のドメイン名を CNAME することもできます。
 | 
			
		||||
9. デプロイが完了するのを待ち、生成されたドメイン名をクリックして One API にアクセスします。
 | 
			
		||||
 | 
			
		||||
</div>
 | 
			
		||||
</details>
 | 
			
		||||
 | 
			
		||||
## コンフィグ
 | 
			
		||||
システムは箱から出してすぐに使えます。
 | 
			
		||||
 | 
			
		||||
環境変数やコマンドラインパラメータを設定することで、システムを構成することができます。
 | 
			
		||||
 | 
			
		||||
システム起動後、`root` ユーザーとしてログインし、さらにシステムを設定します。
 | 
			
		||||
 | 
			
		||||
## 使用方法
 | 
			
		||||
`Channels` ページで API Key を追加し、`Tokens` ページでアクセストークンを追加する。
 | 
			
		||||
 | 
			
		||||
アクセストークンを使って One API にアクセスすることができる。使い方は [OpenAI API](https://platform.openai.com/docs/api-reference/introduction) と同じです。
 | 
			
		||||
 | 
			
		||||
OpenAI API が使用されている場所では、API Base に One API のデプロイアドレスを設定することを忘れないでください(例: `https://openai.justsong.cn`)。API Key は One API で生成されたトークンでなければなりません。
 | 
			
		||||
 | 
			
		||||
具体的な API Base のフォーマットは、使用しているクライアントに依存することに注意してください。
 | 
			
		||||
 | 
			
		||||
```mermaid
 | 
			
		||||
graph LR
 | 
			
		||||
    A(ユーザ)
 | 
			
		||||
    A --->|リクエスト| B(One API)
 | 
			
		||||
    B -->|中継リクエスト| C(OpenAI)
 | 
			
		||||
    B -->|中継リクエスト| D(Azure)
 | 
			
		||||
    B -->|中継リクエスト| E(その他のダウンストリームチャンネル)
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
現在のリクエストにどのチャネルを使うかを指定するには、トークンの後に チャネル ID を追加します: 例えば、`Authorization: Bearer ONE_API_KEY-CHANNEL_ID` のようにします。
 | 
			
		||||
チャンネル ID を指定するためには、トークンは管理者によって作成される必要があることに注意してください。
 | 
			
		||||
 | 
			
		||||
もしチャネル ID が指定されない場合、ロードバランシングによってリクエストが複数のチャネルに振り分けられます。
 | 
			
		||||
 | 
			
		||||
### 環境変数
 | 
			
		||||
1. `REDIS_CONN_STRING`: 設定すると、リクエストレート制限のためのストレージとして、メモリの代わりに Redis が使われる。
 | 
			
		||||
    + 例: `REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
2. `SESSION_SECRET`: 設定すると、固定セッションキーが使用され、システムの再起動後もログインユーザーのクッキーが有効であることが保証されます。
 | 
			
		||||
    + 例: `SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`: 設定すると、SQLite の代わりに指定したデータベースが使用されます。MySQL バージョン 8.0 を使用してください。
 | 
			
		||||
    + 例: `SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
 | 
			
		||||
4. `FRONTEND_BASE_URL`: 設定されると、バックエンドアドレスではなく、指定されたフロントエンドアドレスが使われる。
 | 
			
		||||
    + 例: `FRONTEND_BASE_URL=https://openai.justsong.cn`
 | 
			
		||||
5. `SYNC_FREQUENCY`: 設定された場合、システムは定期的にデータベースからコンフィグを秒単位で同期する。設定されていない場合、同期は行われません。
 | 
			
		||||
    + 例: `SYNC_FREQUENCY=60`
 | 
			
		||||
6. `NODE_TYPE`: 設定すると、ノードのタイプを指定する。有効な値は `master` と `slave` である。設定されていない場合、デフォルトは `master`。
 | 
			
		||||
    + 例: `NODE_TYPE=slave`
 | 
			
		||||
7. `CHANNEL_UPDATE_FREQUENCY`: 設定すると、チャンネル残高を分単位で定期的に更新する。設定されていない場合、更新は行われません。
 | 
			
		||||
    + 例: `CHANNEL_UPDATE_FREQUENCY=1440`
 | 
			
		||||
8. `CHANNEL_TEST_FREQUENCY`: 設定すると、チャンネルを定期的にテストする。設定されていない場合、テストは行われません。
 | 
			
		||||
    + 例: `CHANNEL_TEST_FREQUENCY=1440`
 | 
			
		||||
9. `POLLING_INTERVAL`: チャネル残高の更新とチャネルの可用性をテストするときのリクエスト間の時間間隔 (秒)。デフォルトは間隔なし。
 | 
			
		||||
    + 例: `POLLING_INTERVAL=5`
 | 
			
		||||
 | 
			
		||||
### コマンドラインパラメータ
 | 
			
		||||
1. `--port <port_number>`: サーバがリッスンするポート番号を指定。デフォルトは `3000` です。
 | 
			
		||||
    + 例: `--port 3000`
 | 
			
		||||
2. `--log-dir <log_dir>`: ログディレクトリを指定。設定しない場合、ログは保存されません。
 | 
			
		||||
    + 例: `--log-dir ./logs`
 | 
			
		||||
3. `--version`: システムのバージョン番号を表示して終了する。
 | 
			
		||||
4. `--help`: コマンドの使用法ヘルプとパラメータの説明を表示。
 | 
			
		||||
 | 
			
		||||
## スクリーンショット
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
## FAQ
 | 
			
		||||
1. ノルマとは何か?どのように計算されますか?One API にはノルマ計算の問題はありますか?
 | 
			
		||||
    + ノルマ = グループ倍率 * モデル倍率 * (プロンプトトークンの数 + 完了トークンの数 * 完了倍率)
 | 
			
		||||
    + 完了倍率は、公式の定義と一致するように、GPT3.5 では 1.33、GPT4 では 2 に固定されています。
 | 
			
		||||
    + ストリームモードでない場合、公式 API は消費したトークンの総数を返す。ただし、プロンプトとコンプリートの消費倍率は異なるので注意してください。
 | 
			
		||||
2. アカウント残高は十分なのに、"insufficient quota" と表示されるのはなぜですか?
 | 
			
		||||
    + トークンのクォータが十分かどうかご確認ください。トークンクォータはアカウント残高とは別のものです。
 | 
			
		||||
    + トークンクォータは最大使用量を設定するためのもので、ユーザーが自由に設定できます。
 | 
			
		||||
3. チャンネルを使おうとすると "No available channels" と表示されます。どうすればいいですか?
 | 
			
		||||
    + ユーザーとチャンネルグループの設定を確認してください。
 | 
			
		||||
    + チャンネルモデルの設定も確認してください。
 | 
			
		||||
4. チャンネルテストがエラーを報告する: "invalid character '<' looking for beginning of value"
 | 
			
		||||
    + このエラーは、返された値が有効な JSON ではなく、HTML ページである場合に発生する。
 | 
			
		||||
    + ほとんどの場合、デプロイサイトのIPかプロキシのノードが CloudFlare によってブロックされています。
 | 
			
		||||
5. ChatGPT Next Web でエラーが発生しました: "Failed to fetch"
 | 
			
		||||
    + デプロイ時に `BASE_URL` を設定しないでください。
 | 
			
		||||
    + インターフェイスアドレスと API Key が正しいか再確認してください。
 | 
			
		||||
 | 
			
		||||
## 関連プロジェクト
 | 
			
		||||
[FastGPT](https://github.com/labring/FastGPT): LLM に基づく知識質問応答システム
 | 
			
		||||
 | 
			
		||||
## 注
 | 
			
		||||
本プロジェクトはオープンソースプロジェクトです。OpenAI の[利用規約](https://openai.com/policies/terms-of-use)および**適用される法令**を遵守してご利用ください。違法な目的での利用はご遠慮ください。
 | 
			
		||||
 | 
			
		||||
このプロジェクトは MIT ライセンスで公開されています。これに基づき、ページの最下部に帰属表示と本プロジェクトへのリンクを含める必要があります。
 | 
			
		||||
 | 
			
		||||
このプロジェクトを基にした派生プロジェクトについても同様です。
 | 
			
		||||
 | 
			
		||||
帰属表示を含めたくない場合は、事前に許可を得なければなりません。
 | 
			
		||||
 | 
			
		||||
MIT ライセンスによると、このプロジェクトを利用するリスクと責任は利用者が負うべきであり、このオープンソースプロジェクトの開発者は責任を負いません。
 | 
			
		||||
							
								
								
									
										37
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								README.md
									
									
									
									
									
								
							@@ -1,5 +1,5 @@
 | 
			
		||||
<p align="right">
 | 
			
		||||
   <strong>中文</strong> | <a href="./README.en.md">English</a>
 | 
			
		||||
   <strong>中文</strong> | <a href="./README.en.md">English</a> | <a href="./README.ja.md">日本語</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -51,11 +51,13 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
  <a href="https://iamazing.cn/page/reward">赞赏支持</a>
 | 
			
		||||
</p>
 | 
			
		||||
 | 
			
		||||
> **Note**:本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
> **Note**
 | 
			
		||||
> 本项目为开源项目,使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。
 | 
			
		||||
> 
 | 
			
		||||
> 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。
 | 
			
		||||
 | 
			
		||||
> **Note**:使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
 | 
			
		||||
 | 
			
		||||
> **Warning**:从 `v0.3` 版本升级到 `v0.4` 版本需要手动迁移数据库,请手动执行[数据库迁移脚本](./bin/migration_v0.3-v0.4.sql)。
 | 
			
		||||
> **Warning**
 | 
			
		||||
> 使用 Docker 拉取的最新镜像可能是 `alpha` 版本,如果追求稳定性请手动指定版本。
 | 
			
		||||
 | 
			
		||||
## 功能
 | 
			
		||||
1. 支持多种大模型:
 | 
			
		||||
@@ -66,6 +68,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
   + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
 | 
			
		||||
   + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
 | 
			
		||||
   + [x] [智谱 ChatGLM 系列模型](https://bigmodel.cn)
 | 
			
		||||
   + [x] [360 智脑](https://ai.360.cn)
 | 
			
		||||
2. 支持配置镜像以及众多第三方代理服务:
 | 
			
		||||
   + [x] [OpenAI-SB](https://openai-sb.com)
 | 
			
		||||
   + [x] [API2D](https://api2d.com/r/197971)
 | 
			
		||||
@@ -106,6 +109,8 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用 
 | 
			
		||||
 | 
			
		||||
数据将会保存在宿主机的 `/home/ubuntu/data/one-api` 目录,请确保该目录存在且具有写入权限,或者更改为合适的目录。
 | 
			
		||||
 | 
			
		||||
如果启动失败,请添加 `--privileged=true`,具体参考 https://github.com/songquanpeng/one-api/issues/482 。
 | 
			
		||||
 | 
			
		||||
如果上面的镜像无法拉取,可以尝试使用 GitHub 的 Docker 镜像,将上面的 `justsong/one-api` 替换为 `ghcr.io/songquanpeng/one-api` 即可。
 | 
			
		||||
 | 
			
		||||
如果你的并发量较大,**务必**设置 `SQL_DSN`,详见下面[环境变量](#环境变量)一节。
 | 
			
		||||
@@ -272,18 +277,21 @@ graph LR
 | 
			
		||||
不加的话将会使用负载均衡的方式使用多个渠道。
 | 
			
		||||
 | 
			
		||||
### 环境变量
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为请求频率限制的存储,而非使用内存存储。
 | 
			
		||||
1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。
 | 
			
		||||
   + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153`
 | 
			
		||||
   + 如果数据库访问延迟很低,没有必要启用 Redis,启用后反而会出现数据滞后的问题。
 | 
			
		||||
2. `SESSION_SECRET`:设置之后将使用固定的会话密钥,这样系统重新启动后已登录用户的 cookie 将依旧有效。
 | 
			
		||||
   + 例子:`SESSION_SECRET=random_string`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 8.0 版本。
 | 
			
		||||
   + 例子:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
 | 
			
		||||
3. `SQL_DSN`:设置之后将使用指定数据库而非 SQLite,请使用 MySQL 或 PostgreSQL。
 | 
			
		||||
   + 例子:
 | 
			
		||||
     + MySQL:`SQL_DSN=root:123456@tcp(localhost:3306)/oneapi`
 | 
			
		||||
     + PostgreSQL:`SQL_DSN=postgres://postgres:123456@localhost:5432/oneapi`(适配中,欢迎反馈)
 | 
			
		||||
   + 注意需要提前建立数据库 `oneapi`,无需手动建表,程序将自动建表。
 | 
			
		||||
   + 如果使用本地数据库:部署命令可添加 `--network="host"` 以使得容器内的程序可以访问到宿主机上的 MySQL。
 | 
			
		||||
   + 如果使用云数据库:如果云服务器需要验证身份,需要在连接参数中添加 `?tls=skip-verify`。
 | 
			
		||||
   + 请根据你的数据库配置修改下列参数(或者保持默认值):
 | 
			
		||||
     + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `10`。
 | 
			
		||||
     + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `100`。
 | 
			
		||||
     + `SQL_MAX_IDLE_CONNS`:最大空闲连接数,默认为 `100`。
 | 
			
		||||
     + `SQL_MAX_OPEN_CONNS`:最大打开连接数,默认为 `1000`。
 | 
			
		||||
       + 如果报错 `Error 1040: Too many connections`,请适当减小该值。
 | 
			
		||||
     + `SQL_CONN_MAX_LIFETIME`:连接的最大生命周期,默认为 `60`,单位分钟。
 | 
			
		||||
4. `FRONTEND_BASE_URL`:设置之后将重定向页面请求到指定的地址,仅限从服务器设置。
 | 
			
		||||
@@ -298,6 +306,14 @@ graph LR
 | 
			
		||||
   + 例子:`CHANNEL_TEST_FREQUENCY=1440`
 | 
			
		||||
9. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
 | 
			
		||||
   + 例子:`POLLING_INTERVAL=5`
 | 
			
		||||
10. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_ENABLED=true`
 | 
			
		||||
    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
 | 
			
		||||
11. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
 | 
			
		||||
    + 例子:`BATCH_UPDATE_INTERVAL=5`
 | 
			
		||||
12. 请求频率限制:
 | 
			
		||||
    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
 | 
			
		||||
    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
 | 
			
		||||
 | 
			
		||||
### 命令行参数
 | 
			
		||||
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
 | 
			
		||||
@@ -334,6 +350,7 @@ https://openai.justsong.cn
 | 
			
		||||
5. ChatGPT Next Web 报错:`Failed to fetch`
 | 
			
		||||
   + 部署的时候不要设置 `BASE_URL`。
 | 
			
		||||
   + 检查你的接口地址和 API Key 有没有填对。
 | 
			
		||||
   + 检查是否启用了 HTTPS,浏览器会拦截 HTTPS 域名下的 HTTP 请求。
 | 
			
		||||
6. 报错:`当前分组负载已饱和,请稍后再试`
 | 
			
		||||
   + 上游通道 429 了。
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -94,6 +94,9 @@ var RequestInterval = time.Duration(requestInterval) * time.Second
 | 
			
		||||
 | 
			
		||||
var SyncFrequency = 10 * 60 // unit is second, will be overwritten by SYNC_FREQUENCY
 | 
			
		||||
 | 
			
		||||
var BatchUpdateEnabled = false
 | 
			
		||||
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	RoleGuestUser  = 0
 | 
			
		||||
	RoleCommonUser = 1
 | 
			
		||||
@@ -111,10 +114,10 @@ var (
 | 
			
		||||
// All duration's unit is seconds
 | 
			
		||||
// Shouldn't larger then RateLimitKeyExpirationDuration
 | 
			
		||||
var (
 | 
			
		||||
	GlobalApiRateLimitNum            = 180
 | 
			
		||||
	GlobalApiRateLimitNum            = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
 | 
			
		||||
	GlobalApiRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	GlobalWebRateLimitNum            = 60
 | 
			
		||||
	GlobalWebRateLimitNum            = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60)
 | 
			
		||||
	GlobalWebRateLimitDuration int64 = 3 * 60
 | 
			
		||||
 | 
			
		||||
	UploadRateLimitNum            = 10
 | 
			
		||||
@@ -154,45 +157,53 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ChannelTypeUnknown   = 0
 | 
			
		||||
	ChannelTypeOpenAI    = 1
 | 
			
		||||
	ChannelTypeAPI2D     = 2
 | 
			
		||||
	ChannelTypeAzure     = 3
 | 
			
		||||
	ChannelTypeCloseAI   = 4
 | 
			
		||||
	ChannelTypeOpenAISB  = 5
 | 
			
		||||
	ChannelTypeOpenAIMax = 6
 | 
			
		||||
	ChannelTypeOhMyGPT   = 7
 | 
			
		||||
	ChannelTypeCustom    = 8
 | 
			
		||||
	ChannelTypeAILS      = 9
 | 
			
		||||
	ChannelTypeAIProxy   = 10
 | 
			
		||||
	ChannelTypePaLM      = 11
 | 
			
		||||
	ChannelTypeAPI2GPT   = 12
 | 
			
		||||
	ChannelTypeAIGC2D    = 13
 | 
			
		||||
	ChannelTypeAnthropic = 14
 | 
			
		||||
	ChannelTypeBaidu     = 15
 | 
			
		||||
	ChannelTypeZhipu     = 16
 | 
			
		||||
	ChannelTypeAli       = 17
 | 
			
		||||
	ChannelTypeXunfei    = 18
 | 
			
		||||
	ChannelTypeUnknown        = 0
 | 
			
		||||
	ChannelTypeOpenAI         = 1
 | 
			
		||||
	ChannelTypeAPI2D          = 2
 | 
			
		||||
	ChannelTypeAzure          = 3
 | 
			
		||||
	ChannelTypeCloseAI        = 4
 | 
			
		||||
	ChannelTypeOpenAISB       = 5
 | 
			
		||||
	ChannelTypeOpenAIMax      = 6
 | 
			
		||||
	ChannelTypeOhMyGPT        = 7
 | 
			
		||||
	ChannelTypeCustom         = 8
 | 
			
		||||
	ChannelTypeAILS           = 9
 | 
			
		||||
	ChannelTypeAIProxy        = 10
 | 
			
		||||
	ChannelTypePaLM           = 11
 | 
			
		||||
	ChannelTypeAPI2GPT        = 12
 | 
			
		||||
	ChannelTypeAIGC2D         = 13
 | 
			
		||||
	ChannelTypeAnthropic      = 14
 | 
			
		||||
	ChannelTypeBaidu          = 15
 | 
			
		||||
	ChannelTypeZhipu          = 16
 | 
			
		||||
	ChannelTypeAli            = 17
 | 
			
		||||
	ChannelTypeXunfei         = 18
 | 
			
		||||
	ChannelType360            = 19
 | 
			
		||||
	ChannelTypeOpenRouter     = 20
 | 
			
		||||
	ChannelTypeAIProxyLibrary = 21
 | 
			
		||||
	ChannelTypeFastGPT        = 22
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChannelBaseURLs = []string{
 | 
			
		||||
	"",                               // 0
 | 
			
		||||
	"https://api.openai.com",         // 1
 | 
			
		||||
	"https://oa.api2d.net",           // 2
 | 
			
		||||
	"",                               // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz",  // 4
 | 
			
		||||
	"https://api.openai-sb.com",      // 5
 | 
			
		||||
	"https://api.openaimax.com",      // 6
 | 
			
		||||
	"https://api.ohmygpt.com",        // 7
 | 
			
		||||
	"",                               // 8
 | 
			
		||||
	"https://api.caipacity.com",      // 9
 | 
			
		||||
	"https://api.aiproxy.io",         // 10
 | 
			
		||||
	"",                               // 11
 | 
			
		||||
	"https://api.api2gpt.com",        // 12
 | 
			
		||||
	"https://api.aigc2d.com",         // 13
 | 
			
		||||
	"https://api.anthropic.com",      // 14
 | 
			
		||||
	"https://aip.baidubce.com",       // 15
 | 
			
		||||
	"https://open.bigmodel.cn",       // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com", // 17
 | 
			
		||||
	"",                               // 18
 | 
			
		||||
	"",                                // 0
 | 
			
		||||
	"https://api.openai.com",          // 1
 | 
			
		||||
	"https://oa.api2d.net",            // 2
 | 
			
		||||
	"",                                // 3
 | 
			
		||||
	"https://api.closeai-proxy.xyz",   // 4
 | 
			
		||||
	"https://api.openai-sb.com",       // 5
 | 
			
		||||
	"https://api.openaimax.com",       // 6
 | 
			
		||||
	"https://api.ohmygpt.com",         // 7
 | 
			
		||||
	"",                                // 8
 | 
			
		||||
	"https://api.caipacity.com",       // 9
 | 
			
		||||
	"https://api.aiproxy.io",          // 10
 | 
			
		||||
	"",                                // 11
 | 
			
		||||
	"https://api.api2gpt.com",         // 12
 | 
			
		||||
	"https://api.aigc2d.com",          // 13
 | 
			
		||||
	"https://api.anthropic.com",       // 14
 | 
			
		||||
	"https://aip.baidubce.com",        // 15
 | 
			
		||||
	"https://open.bigmodel.cn",        // 16
 | 
			
		||||
	"https://dashscope.aliyuncs.com",  // 17
 | 
			
		||||
	"",                                // 18
 | 
			
		||||
	"https://ai.360.cn",               // 19
 | 
			
		||||
	"https://openrouter.ai/api",       // 20
 | 
			
		||||
	"https://api.aiproxy.io",          // 21
 | 
			
		||||
	"https://fastgpt.run/api/openapi", // 22
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,9 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import "encoding/json"
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ModelRatio
 | 
			
		||||
// https://platform.openai.com/docs/models/model-endpoint-compatibility
 | 
			
		||||
@@ -10,46 +13,52 @@ import "encoding/json"
 | 
			
		||||
// 1 === $0.002 / 1K tokens
 | 
			
		||||
// 1 === ¥0.014 / 1k tokens
 | 
			
		||||
var ModelRatio = map[string]float64{
 | 
			
		||||
	"gpt-4":                   15,
 | 
			
		||||
	"gpt-4-0314":              15,
 | 
			
		||||
	"gpt-4-0613":              15,
 | 
			
		||||
	"gpt-4-32k":               30,
 | 
			
		||||
	"gpt-4-32k-0314":          30,
 | 
			
		||||
	"gpt-4-32k-0613":          30,
 | 
			
		||||
	"gpt-3.5-turbo":           0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":      0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":       1.5, // $0.003 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-16k-0613":  1.5,
 | 
			
		||||
	"text-ada-001":            0.2,
 | 
			
		||||
	"text-babbage-001":        0.25,
 | 
			
		||||
	"text-curie-001":          1,
 | 
			
		||||
	"text-davinci-002":        10,
 | 
			
		||||
	"text-davinci-003":        10,
 | 
			
		||||
	"text-davinci-edit-001":   10,
 | 
			
		||||
	"code-davinci-edit-001":   10,
 | 
			
		||||
	"whisper-1":               10,
 | 
			
		||||
	"davinci":                 10,
 | 
			
		||||
	"curie":                   10,
 | 
			
		||||
	"babbage":                 10,
 | 
			
		||||
	"ada":                     10,
 | 
			
		||||
	"text-embedding-ada-002":  0.05,
 | 
			
		||||
	"text-search-ada-doc-001": 10,
 | 
			
		||||
	"text-moderation-stable":  0.1,
 | 
			
		||||
	"text-moderation-latest":  0.1,
 | 
			
		||||
	"dall-e":                  8,
 | 
			
		||||
	"claude-instant-1":        0.75,
 | 
			
		||||
	"claude-2":                30,
 | 
			
		||||
	"ERNIE-Bot":               0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":         0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                  1,
 | 
			
		||||
	"chatglm_pro":             0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":            0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                 0.8572, // TBD: https://help.aliyun.com/document_detail/2399482.html?spm=a2c4g.2399482.0.0.1ad347feilAgag
 | 
			
		||||
	"qwen-plus-v1":            0.5715, // Same as above
 | 
			
		||||
	"SparkDesk":               0.8572, // TBD
 | 
			
		||||
	"gpt-4":                     15,
 | 
			
		||||
	"gpt-4-0314":                15,
 | 
			
		||||
	"gpt-4-0613":                15,
 | 
			
		||||
	"gpt-4-32k":                 30,
 | 
			
		||||
	"gpt-4-32k-0314":            30,
 | 
			
		||||
	"gpt-4-32k-0613":            30,
 | 
			
		||||
	"gpt-3.5-turbo":             0.75, // $0.0015 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-0301":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-0613":        0.75,
 | 
			
		||||
	"gpt-3.5-turbo-16k":         1.5, // $0.003 / 1K tokens
 | 
			
		||||
	"gpt-3.5-turbo-16k-0613":    1.5,
 | 
			
		||||
	"text-ada-001":              0.2,
 | 
			
		||||
	"text-babbage-001":          0.25,
 | 
			
		||||
	"text-curie-001":            1,
 | 
			
		||||
	"text-davinci-002":          10,
 | 
			
		||||
	"text-davinci-003":          10,
 | 
			
		||||
	"text-davinci-edit-001":     10,
 | 
			
		||||
	"code-davinci-edit-001":     10,
 | 
			
		||||
	"whisper-1":                 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
 | 
			
		||||
	"davinci":                   10,
 | 
			
		||||
	"curie":                     10,
 | 
			
		||||
	"babbage":                   10,
 | 
			
		||||
	"ada":                       10,
 | 
			
		||||
	"text-embedding-ada-002":    0.05,
 | 
			
		||||
	"text-search-ada-doc-001":   10,
 | 
			
		||||
	"text-moderation-stable":    0.1,
 | 
			
		||||
	"text-moderation-latest":    0.1,
 | 
			
		||||
	"dall-e":                    8,
 | 
			
		||||
	"claude-instant-1":          0.815,  // $1.63 / 1M tokens
 | 
			
		||||
	"claude-2":                  5.51,   // $11.02 / 1M tokens
 | 
			
		||||
	"ERNIE-Bot":                 0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"ERNIE-Bot-turbo":           0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"Embedding-V1":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"PaLM-2":                    1,
 | 
			
		||||
	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"qwen-v1":                   0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"qwen-plus-v1":              1,      // ¥0.014 / 1k tokens
 | 
			
		||||
	"text-embedding-v1":         0.05,   // ¥0.0007 / 1k tokens
 | 
			
		||||
	"SparkDesk":                 1.2858, // ¥0.018 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9":              0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
	"embedding-bert-512-v1":     0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"embedding_s1_v1":           0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"semantic_similarity_s1_v1": 0.0715, // ¥0.001 / 1k tokens
 | 
			
		||||
	"360GPT_S2_V9.4":            0.8572, // ¥0.012 / 1k tokens
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ModelRatio2JSONString() string {
 | 
			
		||||
@@ -73,3 +82,19 @@ func GetModelRatio(name string) float64 {
 | 
			
		||||
	}
 | 
			
		||||
	return ratio
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetCompletionRatio(name string) float64 {
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-3.5") {
 | 
			
		||||
		return 1.333333
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "gpt-4") {
 | 
			
		||||
		return 2
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "claude-instant-1") {
 | 
			
		||||
		return 3.38
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(name, "claude-2") {
 | 
			
		||||
		return 2.965517
 | 
			
		||||
	}
 | 
			
		||||
	return 1
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -61,3 +61,8 @@ func RedisDel(key string) error {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	return RDB.Del(ctx, key).Err()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RedisDecrease(key string, value int64) error {
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	return RDB.DecrBy(ctx, key, value).Err()
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -14,7 +14,7 @@ import (
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIError) {
 | 
			
		||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
 | 
			
		||||
	switch channel.Type {
 | 
			
		||||
	case common.ChannelTypePaLM:
 | 
			
		||||
		fallthrough
 | 
			
		||||
@@ -24,10 +24,19 @@ func testChannel(channel *model.Channel, request ChatRequest) (error, *OpenAIErr
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeZhipu:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeAli:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelType360:
 | 
			
		||||
		fallthrough
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
 | 
			
		||||
	case common.ChannelTypeAzure:
 | 
			
		||||
		request.Model = "gpt-35-turbo"
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	default:
 | 
			
		||||
		request.Model = "gpt-3.5-turbo"
 | 
			
		||||
	}
 | 
			
		||||
@@ -174,7 +183,7 @@ func testAllChannels(notify bool) error {
 | 
			
		||||
				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if shouldDisableChannel(openaiErr) {
 | 
			
		||||
			if shouldDisableChannel(openaiErr, -1) {
 | 
			
		||||
				disableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			channel.UpdateResponseTime(milliseconds)
 | 
			
		||||
 
 | 
			
		||||
@@ -85,7 +85,7 @@ func AddChannel(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	channel.CreatedTime = common.GetTimestamp()
 | 
			
		||||
	keys := strings.Split(channel.Key, "\n")
 | 
			
		||||
	channels := make([]model.Channel, 0)
 | 
			
		||||
	channels := make([]model.Channel, 0, len(keys))
 | 
			
		||||
	for _, key := range keys {
 | 
			
		||||
		if key == "" {
 | 
			
		||||
			continue
 | 
			
		||||
 
 | 
			
		||||
@@ -63,6 +63,15 @@ func init() {
 | 
			
		||||
			Root:       "dall-e",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "whisper-1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "openai",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "whisper-1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "gpt-3.5-turbo",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -351,6 +360,15 @@ func init() {
 | 
			
		||||
			Root:       "qwen-plus-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "text-embedding-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "ali",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "text-embedding-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "SparkDesk",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
@@ -360,6 +378,51 @@ func init() {
 | 
			
		||||
			Root:       "SparkDesk",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "360GPT_S2_V9",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "360GPT_S2_V9",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "embedding-bert-512-v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "embedding-bert-512-v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "embedding_s1_v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "embedding_s1_v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "semantic_similarity_s1_v1",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "semantic_similarity_s1_v1",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			Id:         "360GPT_S2_V9.4",
 | 
			
		||||
			Object:     "model",
 | 
			
		||||
			Created:    1677649963,
 | 
			
		||||
			OwnedBy:    "360",
 | 
			
		||||
			Permission: permission,
 | 
			
		||||
			Root:       "360GPT_S2_V9.4",
 | 
			
		||||
			Parent:     nil,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	openAIModelsMap = make(map[string]OpenAIModels)
 | 
			
		||||
	for _, model := range openAIModels {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										220
									
								
								controller/relay-aiproxy.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,220 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://docs.aiproxy.io/dev/library#使用已经定制好的知识库进行对话问答
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryRequest struct {
 | 
			
		||||
	Model     string `json:"model"`
 | 
			
		||||
	Query     string `json:"query"`
 | 
			
		||||
	LibraryId string `json:"libraryId"`
 | 
			
		||||
	Stream    bool   `json:"stream"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryError struct {
 | 
			
		||||
	ErrCode int    `json:"errCode"`
 | 
			
		||||
	Message string `json:"message"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryDocument struct {
 | 
			
		||||
	Title string `json:"title"`
 | 
			
		||||
	URL   string `json:"url"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryResponse struct {
 | 
			
		||||
	Success   bool                     `json:"success"`
 | 
			
		||||
	Answer    string                   `json:"answer"`
 | 
			
		||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
			
		||||
	AIProxyLibraryError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AIProxyLibraryStreamResponse struct {
 | 
			
		||||
	Content   string                   `json:"content"`
 | 
			
		||||
	Finish    bool                     `json:"finish"`
 | 
			
		||||
	Model     string                   `json:"model"`
 | 
			
		||||
	Documents []AIProxyLibraryDocument `json:"documents"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 | 
			
		||||
	query := ""
 | 
			
		||||
	if len(request.Messages) != 0 {
 | 
			
		||||
		query = request.Messages[len(request.Messages)-1].Content
 | 
			
		||||
	}
 | 
			
		||||
	return &AIProxyLibraryRequest{
 | 
			
		||||
		Model:  request.Model,
 | 
			
		||||
		Stream: request.Stream,
 | 
			
		||||
		Query:  query,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 | 
			
		||||
	if len(documents) == 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
	content := "\n\n参考文档:\n"
 | 
			
		||||
	for i, document := range documents {
 | 
			
		||||
		content += fmt.Sprintf("%d. [%s](%s)\n", i+1, document.Title, document.URL)
 | 
			
		||||
	}
 | 
			
		||||
	return content
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
 | 
			
		||||
	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: content,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: "stop",
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := OpenAITextResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Choices: []OpenAITextResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
	return &fullTextResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func documentsAIProxyLibrary(documents []AIProxyLibraryDocument) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = aiProxyDocuments2Markdown(documents)
 | 
			
		||||
	choice.FinishReason = &stopFinishReason
 | 
			
		||||
	return &ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   "",
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func streamResponseAIProxyLibrary2OpenAI(response *AIProxyLibraryStreamResponse) *ChatCompletionsStreamResponse {
 | 
			
		||||
	var choice ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = response.Content
 | 
			
		||||
	return &ChatCompletionsStreamResponse{
 | 
			
		||||
		Id:      common.GetUUID(),
 | 
			
		||||
		Object:  "chat.completion.chunk",
 | 
			
		||||
		Created: common.GetTimestamp(),
 | 
			
		||||
		Model:   response.Model,
 | 
			
		||||
		Choices: []ChatCompletionsStreamResponseChoice{choice},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyLibraryStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		if i := strings.Index(string(data), "\n"); i >= 0 {
 | 
			
		||||
			return i + 1, data[0:i], nil
 | 
			
		||||
		}
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
	dataChan := make(chan string)
 | 
			
		||||
	stopChan := make(chan bool)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 5 { // ignore blank line or wrong format
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if data[:5] != "data:" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			data = data[5:]
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	var documents []AIProxyLibraryDocument
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
			var AIProxyLibraryResponse AIProxyLibraryStreamResponse
 | 
			
		||||
			err := json.Unmarshal([]byte(data), &AIProxyLibraryResponse)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			if len(AIProxyLibraryResponse.Documents) != 0 {
 | 
			
		||||
				documents = AIProxyLibraryResponse.Documents
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			return true
 | 
			
		||||
		case <-stopChan:
 | 
			
		||||
			response := documentsAIProxyLibrary(documents)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 | 
			
		||||
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 | 
			
		||||
			return false
 | 
			
		||||
		}
 | 
			
		||||
	})
 | 
			
		||||
	err := resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aiProxyLibraryHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var AIProxyLibraryResponse AIProxyLibraryResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &AIProxyLibraryResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	if AIProxyLibraryResponse.ErrCode != 0 {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: AIProxyLibraryResponse.Message,
 | 
			
		||||
				Type:    strconv.Itoa(AIProxyLibraryResponse.ErrCode),
 | 
			
		||||
				Code:    AIProxyLibraryResponse.ErrCode,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	fullTextResponse := responseAIProxyLibrary2OpenAI(&AIProxyLibraryResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
@@ -35,6 +35,29 @@ type AliChatRequest struct {
 | 
			
		||||
	Parameters AliParameters `json:"parameters,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbeddingRequest struct {
 | 
			
		||||
	Model string `json:"model"`
 | 
			
		||||
	Input struct {
 | 
			
		||||
		Texts []string `json:"texts"`
 | 
			
		||||
	} `json:"input"`
 | 
			
		||||
	Parameters *struct {
 | 
			
		||||
		TextType string `json:"text_type,omitempty"`
 | 
			
		||||
	} `json:"parameters,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbedding struct {
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
	TextIndex int       `json:"text_index"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliEmbeddingResponse struct {
 | 
			
		||||
	Output struct {
 | 
			
		||||
		Embeddings []AliEmbedding `json:"embeddings"`
 | 
			
		||||
	} `json:"output"`
 | 
			
		||||
	Usage AliUsage `json:"usage"`
 | 
			
		||||
	AliError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliError struct {
 | 
			
		||||
	Code      string `json:"code"`
 | 
			
		||||
	Message   string `json:"message"`
 | 
			
		||||
@@ -44,6 +67,7 @@ type AliError struct {
 | 
			
		||||
type AliUsage struct {
 | 
			
		||||
	InputTokens  int `json:"input_tokens"`
 | 
			
		||||
	OutputTokens int `json:"output_tokens"`
 | 
			
		||||
	TotalTokens  int `json:"total_tokens"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliOutput struct {
 | 
			
		||||
@@ -95,6 +119,70 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Ali(request GeneralOpenAIRequest) *AliEmbeddingRequest {
 | 
			
		||||
	return &AliEmbeddingRequest{
 | 
			
		||||
		Model: "text-embedding-v1",
 | 
			
		||||
		Input: struct {
 | 
			
		||||
			Texts []string `json:"texts"`
 | 
			
		||||
		}{
 | 
			
		||||
			Texts: request.ParseInput(),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var aliResponse AliEmbeddingResponse
 | 
			
		||||
	err := json.NewDecoder(resp.Body).Decode(&aliResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if aliResponse.Code != "" {
 | 
			
		||||
		return &OpenAIErrorWithStatusCode{
 | 
			
		||||
			OpenAIError: OpenAIError{
 | 
			
		||||
				Message: aliResponse.Message,
 | 
			
		||||
				Type:    aliResponse.Code,
 | 
			
		||||
				Param:   aliResponse.RequestId,
 | 
			
		||||
				Code:    aliResponse.Code,
 | 
			
		||||
			},
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullTextResponse := embeddingResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
	jsonResponse, err := json.Marshal(fullTextResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "application/json")
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
	openAIEmbeddingResponse := OpenAIEmbeddingResponse{
 | 
			
		||||
		Object: "list",
 | 
			
		||||
		Data:   make([]OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
 | 
			
		||||
		Model:  "text-embedding-v1",
 | 
			
		||||
		Usage:  Usage{TotalTokens: response.Usage.TotalTokens},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, item := range response.Output.Embeddings {
 | 
			
		||||
		openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, OpenAIEmbeddingResponseItem{
 | 
			
		||||
			Object:    `embedding`,
 | 
			
		||||
			Index:     item.TextIndex,
 | 
			
		||||
			Embedding: item.Embedding,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	return &openAIEmbeddingResponse
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
 | 
			
		||||
	choice := OpenAITextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
@@ -166,11 +254,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	lastResponseText := ""
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
@@ -181,9 +265,11 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStat
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage.PromptTokens += aliResponse.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens += aliResponse.Usage.OutputTokens
 | 
			
		||||
			usage.TotalTokens += aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
			if aliResponse.Usage.OutputTokens != 0 {
 | 
			
		||||
				usage.PromptTokens = aliResponse.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens = aliResponse.Usage.OutputTokens
 | 
			
		||||
				usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseAli2OpenAI(&aliResponse)
 | 
			
		||||
			response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
 | 
			
		||||
			lastResponseText = aliResponse.Output.Text
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										147
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								controller/relay-audio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,147 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func relayAudioHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	audioModel := "whisper-1"
 | 
			
		||||
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	userId := c.GetInt("id")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
 | 
			
		||||
	preConsumedTokens := common.PreConsumedQuota
 | 
			
		||||
	modelRatio := common.GetModelRatio(audioModel)
 | 
			
		||||
	groupRatio := common.GetGroupRatio(group)
 | 
			
		||||
	ratio := modelRatio * groupRatio
 | 
			
		||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	userQuota, err := model.CacheGetUserQuota(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if userQuota > 100*preConsumedQuota {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
	}
 | 
			
		||||
	if preConsumedQuota > 0 {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// map model name
 | 
			
		||||
	modelMapping := c.GetString("model_mapping")
 | 
			
		||||
	if modelMapping != "" {
 | 
			
		||||
		modelMap := make(map[string]string)
 | 
			
		||||
		err := json.Unmarshal([]byte(modelMapping), &modelMap)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		if modelMap[audioModel] != "" {
 | 
			
		||||
			audioModel = modelMap[audioModel]
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
 | 
			
		||||
	if c.GetString("base_url") != "" {
 | 
			
		||||
		baseURL = c.GetString("base_url")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
 | 
			
		||||
	requestBody := c.Request.Body
 | 
			
		||||
 | 
			
		||||
	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "new_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
	req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
 | 
			
		||||
	resp, err := httpClient.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "do_request_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = req.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = c.Request.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	var audioResponse AudioResponse
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		go func() {
 | 
			
		||||
			quota := countTokenText(audioResponse.Text, audioModel)
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, 0, 0, audioModel, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &audioResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -3,22 +3,22 @@ package controller
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
 | 
			
		||||
 | 
			
		||||
type BaiduTokenResponse struct {
 | 
			
		||||
	RefreshToken  string `json:"refresh_token"`
 | 
			
		||||
	ExpiresIn     int    `json:"expires_in"`
 | 
			
		||||
	SessionKey    string `json:"session_key"`
 | 
			
		||||
	AccessToken   string `json:"access_token"`
 | 
			
		||||
	Scope         string `json:"scope"`
 | 
			
		||||
	SessionSecret string `json:"session_secret"`
 | 
			
		||||
	ExpiresIn   int    `json:"expires_in"`
 | 
			
		||||
	AccessToken string `json:"access_token"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduMessage struct {
 | 
			
		||||
@@ -73,6 +73,16 @@ type BaiduEmbeddingResponse struct {
 | 
			
		||||
	BaiduError
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BaiduAccessToken struct {
 | 
			
		||||
	AccessToken      string    `json:"access_token"`
 | 
			
		||||
	Error            string    `json:"error,omitempty"`
 | 
			
		||||
	ErrorDescription string    `json:"error_description,omitempty"`
 | 
			
		||||
	ExpiresIn        int64     `json:"expires_in,omitempty"`
 | 
			
		||||
	ExpiresAt        time.Time `json:"-"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var baiduTokenStore sync.Map
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 | 
			
		||||
	messages := make([]BaiduMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
@@ -134,16 +144,9 @@ func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *ChatCom
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingRequestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduEmbeddingRequest {
 | 
			
		||||
	baiduEmbeddingRequest := BaiduEmbeddingRequest{
 | 
			
		||||
		Input: nil,
 | 
			
		||||
	return &BaiduEmbeddingRequest{
 | 
			
		||||
		Input: request.ParseInput(),
 | 
			
		||||
	}
 | 
			
		||||
	switch request.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = []string{request.Input.(string)}
 | 
			
		||||
	case []string:
 | 
			
		||||
		baiduEmbeddingRequest.Input = request.Input.([]string)
 | 
			
		||||
	}
 | 
			
		||||
	return &baiduEmbeddingRequest
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *OpenAIEmbeddingResponse {
 | 
			
		||||
@@ -191,11 +194,7 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
@@ -205,9 +204,11 @@ func baiduStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
			
		||||
				common.SysError("error unmarshalling stream response: " + err.Error())
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
			usage.PromptTokens += baiduResponse.Usage.PromptTokens
 | 
			
		||||
			usage.CompletionTokens += baiduResponse.Usage.CompletionTokens
 | 
			
		||||
			usage.TotalTokens += baiduResponse.Usage.TotalTokens
 | 
			
		||||
			if baiduResponse.Usage.TotalTokens != 0 {
 | 
			
		||||
				usage.TotalTokens = baiduResponse.Usage.TotalTokens
 | 
			
		||||
				usage.PromptTokens = baiduResponse.Usage.PromptTokens
 | 
			
		||||
				usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
 | 
			
		||||
			}
 | 
			
		||||
			response := streamResponseBaidu2OpenAI(&baiduResponse)
 | 
			
		||||
			jsonResponse, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -299,3 +300,60 @@ func baiduEmbeddingHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWit
 | 
			
		||||
	_, err = c.Writer.Write(jsonResponse)
 | 
			
		||||
	return nil, &fullTextResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getBaiduAccessToken(apiKey string) (string, error) {
 | 
			
		||||
	if val, ok := baiduTokenStore.Load(apiKey); ok {
 | 
			
		||||
		var accessToken BaiduAccessToken
 | 
			
		||||
		if accessToken, ok = val.(BaiduAccessToken); ok {
 | 
			
		||||
			// soon this will expire
 | 
			
		||||
			if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
 | 
			
		||||
				go func() {
 | 
			
		||||
					_, _ = getBaiduAccessTokenHelper(apiKey)
 | 
			
		||||
				}()
 | 
			
		||||
			}
 | 
			
		||||
			return accessToken.AccessToken, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	accessToken, err := getBaiduAccessTokenHelper(apiKey)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	if accessToken == nil {
 | 
			
		||||
		return "", errors.New("getBaiduAccessToken return a nil token")
 | 
			
		||||
	}
 | 
			
		||||
	return (*accessToken).AccessToken, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
 | 
			
		||||
	parts := strings.Split(apiKey, "|")
 | 
			
		||||
	if len(parts) != 2 {
 | 
			
		||||
		return nil, errors.New("invalid baidu apikey")
 | 
			
		||||
	}
 | 
			
		||||
	req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
 | 
			
		||||
		parts[0], parts[1]), nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	req.Header.Add("Content-Type", "application/json")
 | 
			
		||||
	req.Header.Add("Accept", "application/json")
 | 
			
		||||
	res, err := impatientHTTPClient.Do(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer res.Body.Close()
 | 
			
		||||
 | 
			
		||||
	var accessToken BaiduAccessToken
 | 
			
		||||
	err = json.NewDecoder(res.Body).Decode(&accessToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if accessToken.Error != "" {
 | 
			
		||||
		return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
 | 
			
		||||
	}
 | 
			
		||||
	if accessToken.AccessToken == "" {
 | 
			
		||||
		return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
 | 
			
		||||
	}
 | 
			
		||||
	accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
 | 
			
		||||
	baiduTokenStore.Store(apiKey, accessToken)
 | 
			
		||||
	return &accessToken, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -141,11 +141,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithS
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
 
 | 
			
		||||
@@ -66,11 +66,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*O
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
 
 | 
			
		||||
@@ -143,11 +143,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSta
 | 
			
		||||
		dataChan <- string(jsonResponse)
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@ import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/model"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -21,12 +22,17 @@ const (
 | 
			
		||||
	APITypeZhipu
 | 
			
		||||
	APITypeAli
 | 
			
		||||
	APITypeXunfei
 | 
			
		||||
	APITypeAIProxyLibrary
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var httpClient *http.Client
 | 
			
		||||
var impatientHTTPClient *http.Client
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	httpClient = &http.Client{}
 | 
			
		||||
	impatientHTTPClient = &http.Client{
 | 
			
		||||
		Timeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
@@ -99,6 +105,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		apiType = APITypeAli
 | 
			
		||||
	case common.ChannelTypeXunfei:
 | 
			
		||||
		apiType = APITypeXunfei
 | 
			
		||||
	case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
		apiType = APITypeAIProxyLibrary
 | 
			
		||||
	}
 | 
			
		||||
	baseURL := common.ChannelBaseURLs[channelType]
 | 
			
		||||
	requestURL := c.Request.URL.String()
 | 
			
		||||
@@ -145,7 +153,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
		apiKey := c.Request.Header.Get("Authorization")
 | 
			
		||||
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 | 
			
		||||
		fullRequestURL += "?access_token=" + apiKey // TODO: access token expire in 30 days
 | 
			
		||||
		var err error
 | 
			
		||||
		if apiKey, err = getBaiduAccessToken(apiKey); err != nil {
 | 
			
		||||
			return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		fullRequestURL += "?access_token=" + apiKey
 | 
			
		||||
	case APITypePaLM:
 | 
			
		||||
		fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
 | 
			
		||||
		if baseURL != "" {
 | 
			
		||||
@@ -162,6 +174,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method)
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
 | 
			
		||||
		if relayMode == RelayModeEmbeddings {
 | 
			
		||||
			fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL)
 | 
			
		||||
	}
 | 
			
		||||
	var promptTokens int
 | 
			
		||||
	var completionTokens int
 | 
			
		||||
@@ -185,7 +202,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if userQuota > 10*preConsumedQuota {
 | 
			
		||||
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
	if userQuota > 100*preConsumedQuota {
 | 
			
		||||
		// in this case, we do not pre-consume quota
 | 
			
		||||
		// because the user has enough quota
 | 
			
		||||
		preConsumedQuota = 0
 | 
			
		||||
@@ -244,8 +265,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeAli:
 | 
			
		||||
		aliRequest := requestOpenAI2Ali(textRequest)
 | 
			
		||||
		jsonStr, err := json.Marshal(aliRequest)
 | 
			
		||||
		var jsonStr []byte
 | 
			
		||||
		var err error
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
		case RelayModeEmbeddings:
 | 
			
		||||
			aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest)
 | 
			
		||||
			jsonStr, err = json.Marshal(aliEmbeddingRequest)
 | 
			
		||||
		default:
 | 
			
		||||
			aliRequest := requestOpenAI2Ali(textRequest)
 | 
			
		||||
			jsonStr, err = json.Marshal(aliRequest)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		requestBody = bytes.NewBuffer(jsonStr)
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest)
 | 
			
		||||
		aiProxyLibraryRequest.LibraryId = c.GetString("library_id")
 | 
			
		||||
		jsonStr, err := json.Marshal(aiProxyLibraryRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
@@ -269,6 +306,10 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
				req.Header.Set("api-key", apiKey)
 | 
			
		||||
			} else {
 | 
			
		||||
				req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
 | 
			
		||||
				if channelType == common.ChannelTypeOpenRouter {
 | 
			
		||||
					req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
 | 
			
		||||
					req.Header.Set("X-Title", "One API")
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		case APITypeClaude:
 | 
			
		||||
			req.Header.Set("x-api-key", apiKey)
 | 
			
		||||
@@ -285,6 +326,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			if textRequest.Stream {
 | 
			
		||||
				req.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
			}
 | 
			
		||||
		default:
 | 
			
		||||
			req.Header.Set("Authorization", "Bearer "+apiKey)
 | 
			
		||||
		}
 | 
			
		||||
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
 | 
			
		||||
		req.Header.Set("Accept", c.Request.Header.Get("Accept"))
 | 
			
		||||
@@ -302,54 +345,62 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
		isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | 
			
		||||
 | 
			
		||||
		if resp.StatusCode != http.StatusOK {
 | 
			
		||||
			if preConsumedQuota != 0 {
 | 
			
		||||
				go func() {
 | 
			
		||||
					// return pre-consumed quota
 | 
			
		||||
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						common.SysError("error return pre-consumed quota: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
			}
 | 
			
		||||
			return relayErrorHandler(resp)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	tokenName := c.GetString("token_name")
 | 
			
		||||
	channelId := c.GetInt("channel_id")
 | 
			
		||||
 | 
			
		||||
	defer func() {
 | 
			
		||||
		c.Writer.Flush()
 | 
			
		||||
		if consumeQuota {
 | 
			
		||||
			quota := 0
 | 
			
		||||
			completionRatio := 1.0
 | 
			
		||||
			if strings.HasPrefix(textRequest.Model, "gpt-3.5") {
 | 
			
		||||
				completionRatio = 1.333333
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(textRequest.Model, "gpt-4") {
 | 
			
		||||
				completionRatio = 2
 | 
			
		||||
			}
 | 
			
		||||
		// c.Writer.Flush()
 | 
			
		||||
		go func() {
 | 
			
		||||
			if consumeQuota {
 | 
			
		||||
				quota := 0
 | 
			
		||||
				completionRatio := common.GetCompletionRatio(textRequest.Model)
 | 
			
		||||
				promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
				completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
 | 
			
		||||
			promptTokens = textResponse.Usage.PromptTokens
 | 
			
		||||
			completionTokens = textResponse.Usage.CompletionTokens
 | 
			
		||||
 | 
			
		||||
			quota = promptTokens + int(float64(completionTokens)*completionRatio)
 | 
			
		||||
			quota = int(float64(quota) * ratio)
 | 
			
		||||
			if ratio != 0 && quota <= 0 {
 | 
			
		||||
				quota = 1
 | 
			
		||||
				quota = promptTokens + int(float64(completionTokens)*completionRatio)
 | 
			
		||||
				quota = int(float64(quota) * ratio)
 | 
			
		||||
				if ratio != 0 && quota <= 0 {
 | 
			
		||||
					quota = 1
 | 
			
		||||
				}
 | 
			
		||||
				totalTokens := promptTokens + completionTokens
 | 
			
		||||
				if totalTokens == 0 {
 | 
			
		||||
					// in this case, must be some error happened
 | 
			
		||||
					// we cannot just return, because we may have to return the pre-consumed quota
 | 
			
		||||
					quota = 0
 | 
			
		||||
				}
 | 
			
		||||
				quotaDelta := quota - preConsumedQuota
 | 
			
		||||
				err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
				if quota != 0 {
 | 
			
		||||
					logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
					model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
					model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
					model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			totalTokens := promptTokens + completionTokens
 | 
			
		||||
			if totalTokens == 0 {
 | 
			
		||||
				// in this case, must be some error happened
 | 
			
		||||
				// we cannot just return, because we may have to return the pre-consumed quota
 | 
			
		||||
				quota = 0
 | 
			
		||||
			}
 | 
			
		||||
			quotaDelta := quota - preConsumedQuota
 | 
			
		||||
			err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error consuming token remain quota: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			err = model.CacheUpdateUserQuota(userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("error update user quota cache: " + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
			if quota != 0 {
 | 
			
		||||
				tokenName := c.GetString("token_name")
 | 
			
		||||
				logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
 | 
			
		||||
				model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
 | 
			
		||||
				model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
 | 
			
		||||
				channelId := c.GetInt("channel_id")
 | 
			
		||||
				model.UpdateChannelUsedQuota(channelId, quota)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		}()
 | 
			
		||||
	}()
 | 
			
		||||
	switch apiType {
 | 
			
		||||
	case APITypeOpenAI:
 | 
			
		||||
@@ -471,7 +522,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := aliHandler(c, resp)
 | 
			
		||||
			var err *OpenAIErrorWithStatusCode
 | 
			
		||||
			var usage *Usage
 | 
			
		||||
			switch relayMode {
 | 
			
		||||
			case RelayModeEmbeddings:
 | 
			
		||||
				err, usage = aliEmbeddingHandler(c, resp)
 | 
			
		||||
			default:
 | 
			
		||||
				err, usage = aliHandler(c, resp)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
@@ -499,6 +557,26 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
		} else {
 | 
			
		||||
			return errorWrapper(errors.New("xunfei api does not support non-stream mode"), "invalid_api_type", http.StatusBadRequest)
 | 
			
		||||
		}
 | 
			
		||||
	case APITypeAIProxyLibrary:
 | 
			
		||||
		if isStream {
 | 
			
		||||
			err, usage := aiProxyLibraryStreamHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			err, usage := aiProxyLibraryHandler(c, resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if usage != nil {
 | 
			
		||||
				textResponse.Usage = *usage
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,38 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/pkoukk/tiktoken-go"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"strconv"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var stopFinishReason = "stop"
 | 
			
		||||
 | 
			
		||||
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
 | 
			
		||||
 | 
			
		||||
func InitTokenEncoders() {
 | 
			
		||||
	common.SysLog("initializing token encoders")
 | 
			
		||||
	fallbackTokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.FatalLog(fmt.Sprintf("failed to get fallback token encoder: %s", err.Error()))
 | 
			
		||||
	}
 | 
			
		||||
	for model, _ := range common.ModelRatio {
 | 
			
		||||
		tokenEncoder, err := tiktoken.EncodingForModel(model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError(fmt.Sprintf("using fallback encoder for model %s", model))
 | 
			
		||||
			tokenEncoderMap[model] = fallbackTokenEncoder
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		tokenEncoderMap[model] = tokenEncoder
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("token encoders initialized")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTokenEncoder(model string) *tiktoken.Tiktoken {
 | 
			
		||||
	if tokenEncoder, ok := tokenEncoderMap[model]; ok {
 | 
			
		||||
		return tokenEncoder
 | 
			
		||||
@@ -94,15 +117,53 @@ func errorWrapper(err error, code string, statusCode int) *OpenAIErrorWithStatus
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func shouldDisableChannel(err *OpenAIError) bool {
 | 
			
		||||
func shouldDisableChannel(err *OpenAIError, statusCode int) bool {
 | 
			
		||||
	if !common.AutomaticDisableChannelEnabled {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if statusCode == http.StatusUnauthorized {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	if err.Type == "insufficient_quota" || err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func setEventStreamHeaders(c *gin.Context) {
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func relayErrorHandler(resp *http.Response) (openAIErrorWithStatusCode *OpenAIErrorWithStatusCode) {
 | 
			
		||||
	openAIErrorWithStatusCode = &OpenAIErrorWithStatusCode{
 | 
			
		||||
		StatusCode: resp.StatusCode,
 | 
			
		||||
		OpenAIError: OpenAIError{
 | 
			
		||||
			Message: fmt.Sprintf("bad response status code %d", resp.StatusCode),
 | 
			
		||||
			Type:    "one_api_error",
 | 
			
		||||
			Code:    "bad_response_status_code",
 | 
			
		||||
			Param:   strconv.Itoa(resp.StatusCode),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var textResponse TextResponse
 | 
			
		||||
	err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	openAIErrorWithStatusCode.OpenAIError = textResponse.Error
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -75,7 +75,7 @@ type XunfeiChatResponse struct {
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *XunfeiChatRequest {
 | 
			
		||||
func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, domain string) *XunfeiChatRequest {
 | 
			
		||||
	messages := make([]XunfeiMessage, 0, len(request.Messages))
 | 
			
		||||
	for _, message := range request.Messages {
 | 
			
		||||
		if message.Role == "system" {
 | 
			
		||||
@@ -96,7 +96,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string) *Xun
 | 
			
		||||
	}
 | 
			
		||||
	xunfeiRequest := XunfeiChatRequest{}
 | 
			
		||||
	xunfeiRequest.Header.AppId = xunfeiAppId
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.Domain = "general"
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.Domain = domain
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.TopK = request.N
 | 
			
		||||
	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
 | 
			
		||||
@@ -178,15 +178,28 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string {
 | 
			
		||||
 | 
			
		||||
func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*OpenAIErrorWithStatusCode, *Usage) {
 | 
			
		||||
	var usage Usage
 | 
			
		||||
	query := c.Request.URL.Query()
 | 
			
		||||
	apiVersion := query.Get("api-version")
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = c.GetString("api_version")
 | 
			
		||||
	}
 | 
			
		||||
	if apiVersion == "" {
 | 
			
		||||
		apiVersion = "v1.1"
 | 
			
		||||
		common.SysLog("api_version not found, use default: " + apiVersion)
 | 
			
		||||
	}
 | 
			
		||||
	domain := "general"
 | 
			
		||||
	if apiVersion == "v2.1" {
 | 
			
		||||
		domain = "generalv2"
 | 
			
		||||
	}
 | 
			
		||||
	hostUrl := fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion)
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	hostUrl := "wss://aichat.xf-yun.com/v1/chat"
 | 
			
		||||
	conn, resp, err := d.Dial(buildXunfeiAuthUrl(hostUrl, apiKey, apiSecret), nil)
 | 
			
		||||
	if err != nil || resp.StatusCode != 101 {
 | 
			
		||||
		return errorWrapper(err, "dial_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	data := requestOpenAI2Xunfei(textRequest, appId)
 | 
			
		||||
	data := requestOpenAI2Xunfei(textRequest, appId, domain)
 | 
			
		||||
	err = conn.WriteJSON(data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return errorWrapper(err, "write_json_failed", http.StatusInternalServerError), nil
 | 
			
		||||
@@ -217,11 +230,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest GeneralOpenAIRequest, appId
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case xunfeiResponse := <-dataChan:
 | 
			
		||||
 
 | 
			
		||||
@@ -224,11 +224,7 @@ func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithSt
 | 
			
		||||
		}
 | 
			
		||||
		stopChan <- true
 | 
			
		||||
	}()
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	c.Writer.Header().Set("Cache-Control", "no-cache")
 | 
			
		||||
	c.Writer.Header().Set("Connection", "keep-alive")
 | 
			
		||||
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
 | 
			
		||||
	c.Writer.Header().Set("X-Accel-Buffering", "no")
 | 
			
		||||
	setEventStreamHeaders(c)
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		select {
 | 
			
		||||
		case data := <-dataChan:
 | 
			
		||||
 
 | 
			
		||||
@@ -24,6 +24,7 @@ const (
 | 
			
		||||
	RelayModeModerations
 | 
			
		||||
	RelayModeImagesGenerations
 | 
			
		||||
	RelayModeEdits
 | 
			
		||||
	RelayModeAudio
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// https://platform.openai.com/docs/api-reference/chat
 | 
			
		||||
@@ -40,6 +41,26 @@ type GeneralOpenAIRequest struct {
 | 
			
		||||
	Input       any       `json:"input,omitempty"`
 | 
			
		||||
	Instruction string    `json:"instruction,omitempty"`
 | 
			
		||||
	Size        string    `json:"size,omitempty"`
 | 
			
		||||
	Functions   any       `json:"functions,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r GeneralOpenAIRequest) ParseInput() []string {
 | 
			
		||||
	if r.Input == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	var input []string
 | 
			
		||||
	switch r.Input.(type) {
 | 
			
		||||
	case string:
 | 
			
		||||
		input = []string{r.Input.(string)}
 | 
			
		||||
	case []any:
 | 
			
		||||
		input = make([]string, 0, len(r.Input.([]any)))
 | 
			
		||||
		for _, item := range r.Input.([]any) {
 | 
			
		||||
			if str, ok := item.(string); ok {
 | 
			
		||||
				input = append(input, str)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return input
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatRequest struct {
 | 
			
		||||
@@ -62,6 +83,10 @@ type ImageRequest struct {
 | 
			
		||||
	Size   string `json:"size"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AudioResponse struct {
 | 
			
		||||
	Text string `json:"text,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Usage struct {
 | 
			
		||||
	PromptTokens     int `json:"prompt_tokens"`
 | 
			
		||||
	CompletionTokens int `json:"completion_tokens"`
 | 
			
		||||
@@ -158,11 +183,15 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		relayMode = RelayModeImagesGenerations
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
 | 
			
		||||
		relayMode = RelayModeEdits
 | 
			
		||||
	} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
		relayMode = RelayModeAudio
 | 
			
		||||
	}
 | 
			
		||||
	var err *OpenAIErrorWithStatusCode
 | 
			
		||||
	switch relayMode {
 | 
			
		||||
	case RelayModeImagesGenerations:
 | 
			
		||||
		err = relayImageHelper(c, relayMode)
 | 
			
		||||
	case RelayModeAudio:
 | 
			
		||||
		err = relayAudioHelper(c, relayMode)
 | 
			
		||||
	default:
 | 
			
		||||
		err = relayTextHelper(c, relayMode)
 | 
			
		||||
	}
 | 
			
		||||
@@ -185,7 +214,7 @@ func Relay(c *gin.Context) {
 | 
			
		||||
		channelId := c.GetInt("channel_id")
 | 
			
		||||
		common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message))
 | 
			
		||||
		// https://platform.openai.com/docs/guides/error-codes/api-errors
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError) {
 | 
			
		||||
		if shouldDisableChannel(&err.OpenAIError, err.StatusCode) {
 | 
			
		||||
			channelId := c.GetInt("channel_id")
 | 
			
		||||
			channelName := c.GetString("channel_name")
 | 
			
		||||
			disableChannel(channelId, channelName, err.Message)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								go.mod
									
									
									
									
									
								
							@@ -18,7 +18,7 @@ require (
 | 
			
		||||
	golang.org/x/crypto v0.9.0
 | 
			
		||||
	gorm.io/driver/mysql v1.4.3
 | 
			
		||||
	gorm.io/driver/sqlite v1.4.3
 | 
			
		||||
	gorm.io/gorm v1.24.0
 | 
			
		||||
	gorm.io/gorm v1.25.0
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
@@ -36,6 +36,9 @@ require (
 | 
			
		||||
	github.com/gorilla/context v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/securecookie v1.1.1 // indirect
 | 
			
		||||
	github.com/gorilla/sessions v1.2.1 // indirect
 | 
			
		||||
	github.com/jackc/pgpassfile v1.0.0 // indirect
 | 
			
		||||
	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
 | 
			
		||||
	github.com/jackc/pgx/v5 v5.3.1 // indirect
 | 
			
		||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
			
		||||
	github.com/jinzhu/now v1.1.5 // indirect
 | 
			
		||||
	github.com/json-iterator/go v1.1.12 // indirect
 | 
			
		||||
@@ -54,4 +57,5 @@ require (
 | 
			
		||||
	golang.org/x/text v0.9.0 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.30.0 // indirect
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1 // indirect
 | 
			
		||||
	gorm.io/driver/postgres v1.5.2 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										10
									
								
								go.sum
									
									
									
									
									
								
							@@ -69,6 +69,12 @@ github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7Fsg
 | 
			
		||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
 | 
			
		||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
 | 
			
		||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
 | 
			
		||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
 | 
			
		||||
github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU=
 | 
			
		||||
github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8=
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
 | 
			
		||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
 | 
			
		||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
 | 
			
		||||
@@ -187,9 +193,13 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 | 
			
		||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 | 
			
		||||
gorm.io/driver/mysql v1.4.3 h1:/JhWJhO2v17d8hjApTltKNADm7K7YI2ogkR7avJUL3k=
 | 
			
		||||
gorm.io/driver/mysql v1.4.3/go.mod h1:sSIebwZAVPiT+27jK9HIwvsqOGKx3YMPmrA3mBJR10c=
 | 
			
		||||
gorm.io/driver/postgres v1.5.2 h1:ytTDxxEv+MplXOfFe3Lzm7SjG09fcdb3Z/c056DTBx0=
 | 
			
		||||
gorm.io/driver/postgres v1.5.2/go.mod h1:fmpX0m2I1PKuR7mKZiEluwrP3hbs+ps7JIGMUBpCgl8=
 | 
			
		||||
gorm.io/driver/sqlite v1.4.3 h1:HBBcZSDnWi5BW3B3rwvVTc510KGkBkexlOg0QrmLUuU=
 | 
			
		||||
gorm.io/driver/sqlite v1.4.3/go.mod h1:0Aq3iPO+v9ZKbcdiz8gLWRw5VOPcBOPUQJFLq5e2ecI=
 | 
			
		||||
gorm.io/gorm v1.23.8/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
 | 
			
		||||
gorm.io/gorm v1.24.0 h1:j/CoiSm6xpRpmzbFJsQHYj+I8bGYWLXVHeYEyyKlF74=
 | 
			
		||||
gorm.io/gorm v1.24.0/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
 | 
			
		||||
gorm.io/gorm v1.25.0 h1:+KtYtb2roDz14EQe4bla8CbQlmb9dN3VejSai3lprfU=
 | 
			
		||||
gorm.io/gorm v1.25.0/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
 | 
			
		||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
 | 
			
		||||
 
 | 
			
		||||
@@ -519,5 +519,10 @@
 | 
			
		||||
  "令牌创建成功,请在列表页面点击复制获取令牌!": "Token created successfully, please click copy on the list page to get the token!",
 | 
			
		||||
  "代理": "Proxy",
 | 
			
		||||
  "此项可选,用于通过代理站来进行 API 调用,请输入代理站地址,格式为:https://domain.com": "This is optional, used to make API calls through the proxy site, please enter the proxy site address, the format is: https://domain.com",
 | 
			
		||||
  "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?"
 | 
			
		||||
  "取消密码登录将导致所有未绑定其他登录方式的用户(包括管理员)无法通过密码登录,确认取消?": "Canceling password login will cause all users (including administrators) who have not bound other login methods to be unable to log in via password, confirm cancel?",
 | 
			
		||||
  "按照如下格式输入:": "Enter in the following format:",
 | 
			
		||||
  "模型版本": "Model version",
 | 
			
		||||
  "请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1": "Please enter the version of the Starfire model, note that it is the version number in the interface address, for example: v2.1",
 | 
			
		||||
  "点击查看": "click to view",
 | 
			
		||||
  "请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!": "Please make sure that the gpt-35-turbo model has been created on Azure, and the apiVersion has been filled in correctly!"
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.go
									
									
									
									
									
								
							@@ -77,6 +77,12 @@ func main() {
 | 
			
		||||
		}
 | 
			
		||||
		go controller.AutomaticallyTestChannels(frequency)
 | 
			
		||||
	}
 | 
			
		||||
	if os.Getenv("BATCH_UPDATE_ENABLED") == "true" {
 | 
			
		||||
		common.BatchUpdateEnabled = true
 | 
			
		||||
		common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s")
 | 
			
		||||
		model.InitBatchUpdater()
 | 
			
		||||
	}
 | 
			
		||||
	controller.InitTokenEncoders()
 | 
			
		||||
 | 
			
		||||
	// Initialize HTTP server
 | 
			
		||||
	server := gin.Default()
 | 
			
		||||
 
 | 
			
		||||
@@ -100,7 +100,18 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !model.CacheIsUserEnabled(token.UserId) {
 | 
			
		||||
		userEnabled, err := model.IsUserEnabled(token.UserId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.JSON(http.StatusInternalServerError, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": err.Error(),
 | 
			
		||||
					"type":    "one_api_error",
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if !userEnabled {
 | 
			
		||||
			c.JSON(http.StatusForbidden, gin.H{
 | 
			
		||||
				"error": gin.H{
 | 
			
		||||
					"message": "用户已被封禁",
 | 
			
		||||
 
 | 
			
		||||
@@ -58,7 +58,10 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		} else {
 | 
			
		||||
			// Select a channel for the user
 | 
			
		||||
			var modelRequest ModelRequest
 | 
			
		||||
			err := common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			var err error
 | 
			
		||||
			if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
				err = common.UnmarshalBodyReusable(c, &modelRequest)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(http.StatusBadRequest, gin.H{
 | 
			
		||||
					"error": gin.H{
 | 
			
		||||
@@ -84,6 +87,11 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					modelRequest.Model = "dall-e"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
 | 
			
		||||
				if modelRequest.Model == "" {
 | 
			
		||||
					modelRequest.Model = "whisper-1"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
 | 
			
		||||
@@ -107,8 +115,13 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
		c.Set("model_mapping", channel.ModelMapping)
 | 
			
		||||
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
 | 
			
		||||
		c.Set("base_url", channel.BaseURL)
 | 
			
		||||
		if channel.Type == common.ChannelTypeAzure {
 | 
			
		||||
		switch channel.Type {
 | 
			
		||||
		case common.ChannelTypeAzure:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeXunfei:
 | 
			
		||||
			c.Set("api_version", channel.Other)
 | 
			
		||||
		case common.ChannelTypeAIProxyLibrary:
 | 
			
		||||
			c.Set("library_id", channel.Other)
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -95,23 +95,36 @@ func CacheUpdateUserQuota(id int) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheIsUserEnabled(userId int) bool {
 | 
			
		||||
func CacheDecreaseUserQuota(id int, quota int) error {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	err := common.RedisDecrease(fmt.Sprintf("user_quota:%d", id), int64(quota))
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CacheIsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
	if !common.RedisEnabled {
 | 
			
		||||
		return IsUserEnabled(userId)
 | 
			
		||||
	}
 | 
			
		||||
	enabled, err := common.RedisGet(fmt.Sprintf("user_enabled:%d", userId))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		status := common.UserStatusDisabled
 | 
			
		||||
		if IsUserEnabled(userId) {
 | 
			
		||||
			status = common.UserStatusEnabled
 | 
			
		||||
		}
 | 
			
		||||
		enabled = fmt.Sprintf("%d", status)
 | 
			
		||||
		err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			common.SysError("Redis set user enabled error: " + err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		return enabled == "1", nil
 | 
			
		||||
	}
 | 
			
		||||
	return enabled == "1"
 | 
			
		||||
 | 
			
		||||
	userEnabled, err := IsUserEnabled(userId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	enabled = "0"
 | 
			
		||||
	if userEnabled {
 | 
			
		||||
		enabled = "1"
 | 
			
		||||
	}
 | 
			
		||||
	err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("Redis set user enabled error: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	return userEnabled, err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var group2model2channels map[string]map[string][]*Channel
 | 
			
		||||
 
 | 
			
		||||
@@ -141,7 +141,15 @@ func UpdateChannelStatusById(id int, status int) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
	err := DB.Set("gorm:query_option", "FOR UPDATE").Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	updateChannelUsedQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateChannelUsedQuota(id int, quota int) {
 | 
			
		||||
	err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("failed to update channel used quota: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -2,10 +2,12 @@ package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/driver/postgres"
 | 
			
		||||
	"gorm.io/driver/sqlite"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"os"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -34,28 +36,35 @@ func createRootAccountIfNeed() error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func CountTable(tableName string) (num int64) {
 | 
			
		||||
	DB.Table(tableName).Count(&num)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitDB() (err error) {
 | 
			
		||||
	var db *gorm.DB
 | 
			
		||||
func chooseDB() (*gorm.DB, error) {
 | 
			
		||||
	if os.Getenv("SQL_DSN") != "" {
 | 
			
		||||
		dsn := os.Getenv("SQL_DSN")
 | 
			
		||||
		if strings.HasPrefix(dsn, "postgres://") {
 | 
			
		||||
			// Use PostgreSQL
 | 
			
		||||
			common.SysLog("using PostgreSQL as database")
 | 
			
		||||
			return gorm.Open(postgres.New(postgres.Config{
 | 
			
		||||
				DSN:                  dsn,
 | 
			
		||||
				PreferSimpleProtocol: true, // disables implicit prepared statement usage
 | 
			
		||||
			}), &gorm.Config{
 | 
			
		||||
				PrepareStmt: true, // precompile SQL
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
		// Use MySQL
 | 
			
		||||
		common.SysLog("using MySQL as database")
 | 
			
		||||
		db, err = gorm.Open(mysql.Open(os.Getenv("SQL_DSN")), &gorm.Config{
 | 
			
		||||
			PrepareStmt: true, // precompile SQL
 | 
			
		||||
		})
 | 
			
		||||
	} else {
 | 
			
		||||
		// Use SQLite
 | 
			
		||||
		common.SysLog("SQL_DSN not set, using SQLite as database")
 | 
			
		||||
		common.UsingSQLite = true
 | 
			
		||||
		db, err = gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 | 
			
		||||
		return gorm.Open(mysql.Open(dsn), &gorm.Config{
 | 
			
		||||
			PrepareStmt: true, // precompile SQL
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("database connected")
 | 
			
		||||
	// Use SQLite
 | 
			
		||||
	common.SysLog("SQL_DSN not set, using SQLite as database")
 | 
			
		||||
	common.UsingSQLite = true
 | 
			
		||||
	return gorm.Open(sqlite.Open(common.SQLitePath), &gorm.Config{
 | 
			
		||||
		PrepareStmt: true, // precompile SQL
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitDB() (err error) {
 | 
			
		||||
	db, err := chooseDB()
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if common.DebugEnabled {
 | 
			
		||||
			db = db.Debug()
 | 
			
		||||
@@ -65,8 +74,8 @@ func InitDB() (err error) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 10))
 | 
			
		||||
		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 100))
 | 
			
		||||
		sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100))
 | 
			
		||||
		sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000))
 | 
			
		||||
		sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60)))
 | 
			
		||||
 | 
			
		||||
		if !common.IsMasterNode {
 | 
			
		||||
 
 | 
			
		||||
@@ -39,32 +39,35 @@ func ValidateUserToken(key string) (token *Token, err error) {
 | 
			
		||||
	}
 | 
			
		||||
	token, err = CacheGetTokenByKey(key)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		if token.Status == common.TokenStatusExhausted {
 | 
			
		||||
			return nil, errors.New("该令牌额度已用尽")
 | 
			
		||||
		} else if token.Status == common.TokenStatusExpired {
 | 
			
		||||
			return nil, errors.New("该令牌已过期")
 | 
			
		||||
		}
 | 
			
		||||
		if token.Status != common.TokenStatusEnabled {
 | 
			
		||||
			return nil, errors.New("该令牌状态不可用")
 | 
			
		||||
		}
 | 
			
		||||
		if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() {
 | 
			
		||||
			token.Status = common.TokenStatusExpired
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
			if !common.RedisEnabled {
 | 
			
		||||
				token.Status = common.TokenStatusExpired
 | 
			
		||||
				err := token.SelectUpdate()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return nil, errors.New("该令牌已过期")
 | 
			
		||||
		}
 | 
			
		||||
		if !token.UnlimitedQuota && token.RemainQuota <= 0 {
 | 
			
		||||
			token.Status = common.TokenStatusExhausted
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
			if !common.RedisEnabled {
 | 
			
		||||
				// in this case, we can make sure the token is exhausted
 | 
			
		||||
				token.Status = common.TokenStatusExhausted
 | 
			
		||||
				err := token.SelectUpdate()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to update token status" + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return nil, errors.New("该令牌额度已用尽")
 | 
			
		||||
		}
 | 
			
		||||
		go func() {
 | 
			
		||||
			token.AccessedTime = common.GetTimestamp()
 | 
			
		||||
			err := token.SelectUpdate()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				common.SysError("failed to update token" + err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		return token, nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errors.New("无效的令牌")
 | 
			
		||||
@@ -131,10 +134,19 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeTokenQuota, id, quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return increaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota": gorm.Expr("remain_quota + ?", quota),
 | 
			
		||||
			"used_quota":   gorm.Expr("used_quota - ?", quota),
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota + ?", quota),
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota - ?", quota),
 | 
			
		||||
			"accessed_time": common.GetTimestamp(),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	return err
 | 
			
		||||
@@ -144,10 +156,19 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeTokenQuota, id, -quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return decreaseTokenQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseTokenQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&Token{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"remain_quota": gorm.Expr("remain_quota - ?", quota),
 | 
			
		||||
			"used_quota":   gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"remain_quota":  gorm.Expr("remain_quota - ?", quota),
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"accessed_time": common.GetTimestamp(),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	return err
 | 
			
		||||
 
 | 
			
		||||
@@ -226,17 +226,16 @@ func IsAdmin(userId int) bool {
 | 
			
		||||
	return user.Role >= common.RoleAdminUser
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsUserEnabled(userId int) bool {
 | 
			
		||||
func IsUserEnabled(userId int) (bool, error) {
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		return false
 | 
			
		||||
		return false, errors.New("user id is empty")
 | 
			
		||||
	}
 | 
			
		||||
	var user User
 | 
			
		||||
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		common.SysError("no such user " + err.Error())
 | 
			
		||||
		return false
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return user.Status == common.UserStatusEnabled
 | 
			
		||||
	return user.Status == common.UserStatusEnabled, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ValidateAccessToken(token string) (user *User) {
 | 
			
		||||
@@ -275,7 +274,15 @@ func IncreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return increaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func increaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -283,7 +290,15 @@ func DecreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	if quota < 0 {
 | 
			
		||||
		return errors.New("quota 不能为负数!")
 | 
			
		||||
	}
 | 
			
		||||
	err = DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	return decreaseUserQuota(id, quota)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func decreaseUserQuota(id int, quota int) (err error) {
 | 
			
		||||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -293,10 +308,18 @@ func GetRootUserEmail() (email string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
 | 
			
		||||
	err := DB.Set("gorm:query_option", "FOR UPDATE").Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
	if common.BatchUpdateEnabled {
 | 
			
		||||
		addNewRecord(BatchUpdateTypeUsedQuotaAndRequestCount, id, quota)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
 | 
			
		||||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
			
		||||
		map[string]interface{}{
 | 
			
		||||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
			
		||||
			"request_count": gorm.Expr("request_count + ?", 1),
 | 
			
		||||
			"request_count": gorm.Expr("request_count + ?", count),
 | 
			
		||||
		},
 | 
			
		||||
	).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								model/utils.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,75 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const BatchUpdateTypeCount = 4 // if you add a new type, you need to add a new map and a new lock
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	BatchUpdateTypeUserQuota = iota
 | 
			
		||||
	BatchUpdateTypeTokenQuota
 | 
			
		||||
	BatchUpdateTypeUsedQuotaAndRequestCount
 | 
			
		||||
	BatchUpdateTypeChannelUsedQuota
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var batchUpdateStores []map[int]int
 | 
			
		||||
var batchUpdateLocks []sync.Mutex
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateStores = append(batchUpdateStores, make(map[int]int))
 | 
			
		||||
		batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitBatchUpdater() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second)
 | 
			
		||||
			batchUpdate()
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func addNewRecord(type_ int, id int, value int) {
 | 
			
		||||
	batchUpdateLocks[type_].Lock()
 | 
			
		||||
	defer batchUpdateLocks[type_].Unlock()
 | 
			
		||||
	if _, ok := batchUpdateStores[type_][id]; !ok {
 | 
			
		||||
		batchUpdateStores[type_][id] = value
 | 
			
		||||
	} else {
 | 
			
		||||
		batchUpdateStores[type_][id] += value
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func batchUpdate() {
 | 
			
		||||
	common.SysLog("batch update started")
 | 
			
		||||
	for i := 0; i < BatchUpdateTypeCount; i++ {
 | 
			
		||||
		batchUpdateLocks[i].Lock()
 | 
			
		||||
		store := batchUpdateStores[i]
 | 
			
		||||
		batchUpdateStores[i] = make(map[int]int)
 | 
			
		||||
		batchUpdateLocks[i].Unlock()
 | 
			
		||||
 | 
			
		||||
		for key, value := range store {
 | 
			
		||||
			switch i {
 | 
			
		||||
			case BatchUpdateTypeUserQuota:
 | 
			
		||||
				err := increaseUserQuota(key, value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to batch update user quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			case BatchUpdateTypeTokenQuota:
 | 
			
		||||
				err := increaseTokenQuota(key, value)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					common.SysError("failed to batch update token quota: " + err.Error())
 | 
			
		||||
				}
 | 
			
		||||
			case BatchUpdateTypeUsedQuotaAndRequestCount:
 | 
			
		||||
				updateUserUsedQuotaAndRequestCount(key, value, 1) // TODO: count is incorrect
 | 
			
		||||
			case BatchUpdateTypeChannelUsedQuota:
 | 
			
		||||
				updateChannelUsedQuota(key, value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	common.SysLog("batch update finished")
 | 
			
		||||
}
 | 
			
		||||
@@ -26,8 +26,8 @@ func SetRelayRouter(router *gin.Engine) {
 | 
			
		||||
		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/audio/transcriptions", controller.Relay)
 | 
			
		||||
		relayV1Router.POST("/audio/translations", controller.Relay)
 | 
			
		||||
		relayV1Router.GET("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.POST("/files", controller.RelayNotImplemented)
 | 
			
		||||
		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
 | 
			
		||||
 
 | 
			
		||||
@@ -368,7 +368,7 @@ const ChannelsTable = () => {
 | 
			
		||||
                      }} style={{ cursor: 'pointer' }}>
 | 
			
		||||
                      {renderBalance(channel.type, channel.balance)}
 | 
			
		||||
                    </span>}
 | 
			
		||||
                      content="点击更新"
 | 
			
		||||
                      content='点击更新'
 | 
			
		||||
                      basic
 | 
			
		||||
                    />
 | 
			
		||||
                  </Table.Cell>
 | 
			
		||||
@@ -447,8 +447,8 @@ const ChannelsTable = () => {
 | 
			
		||||
              <Button size='small' loading={loading} onClick={testAllChannels}>
 | 
			
		||||
                测试所有已启用通道
 | 
			
		||||
              </Button>
 | 
			
		||||
              {/* <Button size='small' onClick={updateAllChannelsBalance}
 | 
			
		||||
                      loading={loading || updatingBalance}>更新所有已启用通道余额</Button> */}
 | 
			
		||||
              <Button size='small' onClick={updateAllChannelsBalance}
 | 
			
		||||
                      loading={loading || updatingBalance}>更新所有已启用通道余额</Button>
 | 
			
		||||
              <Pagination
 | 
			
		||||
                floated='right'
 | 
			
		||||
                activePage={activePage}
 | 
			
		||||
 
 | 
			
		||||
@@ -43,6 +43,7 @@ function renderType(type) {
 | 
			
		||||
 | 
			
		||||
const LogsTable = () => {
 | 
			
		||||
  const [logs, setLogs] = useState([]);
 | 
			
		||||
  const [showStat, setShowStat] = useState(false);
 | 
			
		||||
  const [loading, setLoading] = useState(true);
 | 
			
		||||
  const [activePage, setActivePage] = useState(1);
 | 
			
		||||
  const [searchKeyword, setSearchKeyword] = useState('');
 | 
			
		||||
@@ -92,6 +93,17 @@ const LogsTable = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const handleEyeClick = async () => {
 | 
			
		||||
    if (!showStat) {
 | 
			
		||||
      if (isAdminUser) {
 | 
			
		||||
        await getLogStat();
 | 
			
		||||
      } else {
 | 
			
		||||
        await getLogSelfStat();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    setShowStat(!showStat);
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const loadLogs = async (startIdx) => {
 | 
			
		||||
    let url = '';
 | 
			
		||||
    let localStartTimestamp = Date.parse(start_timestamp) / 1000;
 | 
			
		||||
@@ -129,13 +141,8 @@ const LogsTable = () => {
 | 
			
		||||
 | 
			
		||||
  const refresh = async () => {
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    setActivePage(1)
 | 
			
		||||
    setActivePage(1);
 | 
			
		||||
    await loadLogs(0);
 | 
			
		||||
    if (isAdminUser) {
 | 
			
		||||
      getLogStat().then();
 | 
			
		||||
    } else {
 | 
			
		||||
      getLogSelfStat().then();
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
@@ -169,7 +176,7 @@ const LogsTable = () => {
 | 
			
		||||
    if (logs.length === 0) return;
 | 
			
		||||
    setLoading(true);
 | 
			
		||||
    let sortedLogs = [...logs];
 | 
			
		||||
    if (typeof sortedLogs[0][key] === 'string'){
 | 
			
		||||
    if (typeof sortedLogs[0][key] === 'string') {
 | 
			
		||||
      sortedLogs.sort((a, b) => {
 | 
			
		||||
        return ('' + a[key]).localeCompare(b[key]);
 | 
			
		||||
      });
 | 
			
		||||
@@ -190,7 +197,12 @@ const LogsTable = () => {
 | 
			
		||||
  return (
 | 
			
		||||
    <>
 | 
			
		||||
      <Segment>
 | 
			
		||||
        <Header as='h3'>使用明细(总消耗额度:{renderQuota(stat.quota)})</Header>
 | 
			
		||||
        <Header as='h3'>
 | 
			
		||||
          使用明细(总消耗额度:
 | 
			
		||||
          {showStat && renderQuota(stat.quota)}
 | 
			
		||||
          {!showStat && <span onClick={handleEyeClick} style={{ cursor: 'pointer', color: 'gray' }}>点击查看</span>}
 | 
			
		||||
          )
 | 
			
		||||
        </Header>
 | 
			
		||||
        <Form>
 | 
			
		||||
          <Form.Group>
 | 
			
		||||
            {
 | 
			
		||||
@@ -312,7 +324,7 @@ const LogsTable = () => {
 | 
			
		||||
              .map((log, idx) => {
 | 
			
		||||
                if (log.deleted) return <></>;
 | 
			
		||||
                return (
 | 
			
		||||
                  <Table.Row key={log.created_at}>
 | 
			
		||||
                  <Table.Row key={log.id}>
 | 
			
		||||
                    <Table.Cell>{renderTimestamp(log.created_at)}</Table.Cell>
 | 
			
		||||
                    {
 | 
			
		||||
                      isAdminUser && (
 | 
			
		||||
 
 | 
			
		||||
@@ -7,7 +7,11 @@ export const CHANNEL_OPTIONS = [
 | 
			
		||||
  { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
 | 
			
		||||
  { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
 | 
			
		||||
  { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' },
 | 
			
		||||
  { key: 19, text: '360 智脑', value: 19, color: 'blue' },
 | 
			
		||||
  { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
 | 
			
		||||
  { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' },
 | 
			
		||||
  { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },
 | 
			
		||||
  { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' },
 | 
			
		||||
  { key: 2, text: '代理:API2D', value: 2, color: 'blue' },
 | 
			
		||||
  { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' },
 | 
			
		||||
  { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' },
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import React, { useEffect, useState } from 'react';
 | 
			
		||||
import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 | 
			
		||||
import { useParams, useNavigate } from 'react-router-dom';
 | 
			
		||||
import { useNavigate, useParams } from 'react-router-dom';
 | 
			
		||||
import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 | 
			
		||||
import { CHANNEL_OPTIONS } from '../../constants';
 | 
			
		||||
 | 
			
		||||
@@ -10,6 +10,20 @@ const MODEL_MAPPING_EXAMPLE = {
 | 
			
		||||
  'gpt-4-32k-0314': 'gpt-4-32k'
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
function type2secretPrompt(type) {
 | 
			
		||||
  // inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
 | 
			
		||||
  switch (type) {
 | 
			
		||||
    case 15:
 | 
			
		||||
      return '按照如下格式输入:APIKey|SecretKey';
 | 
			
		||||
    case 18:
 | 
			
		||||
      return '按照如下格式输入:APPID|APISecret|APIKey';
 | 
			
		||||
    case 22:
 | 
			
		||||
      return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
 | 
			
		||||
    default:
 | 
			
		||||
      return '请输入渠道对应的鉴权密钥';
 | 
			
		||||
  }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const EditChannel = () => {
 | 
			
		||||
  const params = useParams();
 | 
			
		||||
  const navigate = useNavigate();
 | 
			
		||||
@@ -19,7 +33,7 @@ const EditChannel = () => {
 | 
			
		||||
  const handleCancel = () => {
 | 
			
		||||
    navigate('/channel');
 | 
			
		||||
  };
 | 
			
		||||
  
 | 
			
		||||
 | 
			
		||||
  const originInputs = {
 | 
			
		||||
    name: '',
 | 
			
		||||
    type: 1,
 | 
			
		||||
@@ -53,7 +67,7 @@ const EditChannel = () => {
 | 
			
		||||
          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'Embedding-V1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 17:
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1'];
 | 
			
		||||
          localModels = ['qwen-v1', 'qwen-plus-v1', 'text-embedding-v1'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 16:
 | 
			
		||||
          localModels = ['chatglm_pro', 'chatglm_std', 'chatglm_lite'];
 | 
			
		||||
@@ -61,6 +75,9 @@ const EditChannel = () => {
 | 
			
		||||
        case 18:
 | 
			
		||||
          localModels = ['SparkDesk'];
 | 
			
		||||
          break;
 | 
			
		||||
        case 19:
 | 
			
		||||
          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1', '360GPT_S2_V9.4'];
 | 
			
		||||
          break;
 | 
			
		||||
      }
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, models: localModels }));
 | 
			
		||||
    }
 | 
			
		||||
@@ -163,6 +180,9 @@ const EditChannel = () => {
 | 
			
		||||
    if (localInputs.type === 3 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = '2023-06-01-preview';
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.type === 18 && localInputs.other === '') {
 | 
			
		||||
      localInputs.other = 'v2.1';
 | 
			
		||||
    }
 | 
			
		||||
    if (localInputs.model_mapping === '') {
 | 
			
		||||
      localInputs.model_mapping = '{}';
 | 
			
		||||
    }
 | 
			
		||||
@@ -187,6 +207,24 @@ const EditChannel = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const addCustomModel = () => {
 | 
			
		||||
    if (customModel.trim() === '') return;
 | 
			
		||||
    if (inputs.models.includes(customModel)) return;
 | 
			
		||||
    let localModels = [...inputs.models];
 | 
			
		||||
    localModels.push(customModel);
 | 
			
		||||
    let localModelOptions = [];
 | 
			
		||||
    localModelOptions.push({
 | 
			
		||||
      key: customModel,
 | 
			
		||||
      text: customModel,
 | 
			
		||||
      value: customModel
 | 
			
		||||
    });
 | 
			
		||||
    setModelOptions(modelOptions => {
 | 
			
		||||
      return [...modelOptions, ...localModelOptions];
 | 
			
		||||
    });
 | 
			
		||||
    setCustomModel('');
 | 
			
		||||
    handleInputChange(null, { name: 'models', value: localModels });
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  return (
 | 
			
		||||
    <>
 | 
			
		||||
      <Segment loading={loading}>
 | 
			
		||||
@@ -275,6 +313,34 @@ const EditChannel = () => {
 | 
			
		||||
              options={groupOptions}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 18 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='模型版本'
 | 
			
		||||
                  name='other'
 | 
			
		||||
                  placeholder={'请输入星火大模型版本,注意是接口地址中的版本号,例如:v2.1'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.other}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 21 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='知识库 ID'
 | 
			
		||||
                  name='other'
 | 
			
		||||
                  placeholder={'请输入知识库 ID,例如:123456'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.other}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Dropdown
 | 
			
		||||
              label='模型'
 | 
			
		||||
@@ -302,29 +368,19 @@ const EditChannel = () => {
 | 
			
		||||
            }}>清除所有模型</Button>
 | 
			
		||||
            <Input
 | 
			
		||||
              action={
 | 
			
		||||
                <Button type={'button'} onClick={() => {
 | 
			
		||||
                  if (customModel.trim() === '') return;
 | 
			
		||||
                  if (inputs.models.includes(customModel)) return;
 | 
			
		||||
                  let localModels = [...inputs.models];
 | 
			
		||||
                  localModels.push(customModel);
 | 
			
		||||
                  let localModelOptions = [];
 | 
			
		||||
                  localModelOptions.push({
 | 
			
		||||
                    key: customModel,
 | 
			
		||||
                    text: customModel,
 | 
			
		||||
                    value: customModel
 | 
			
		||||
                  });
 | 
			
		||||
                  setModelOptions(modelOptions => {
 | 
			
		||||
                    return [...modelOptions, ...localModelOptions];
 | 
			
		||||
                  });
 | 
			
		||||
                  setCustomModel('');
 | 
			
		||||
                  handleInputChange(null, { name: 'models', value: localModels });
 | 
			
		||||
                }}>填入</Button>
 | 
			
		||||
                <Button type={'button'} onClick={addCustomModel}>填入</Button>
 | 
			
		||||
              }
 | 
			
		||||
              placeholder='输入自定义模型名称'
 | 
			
		||||
              value={customModel}
 | 
			
		||||
              onChange={(e, { value }) => {
 | 
			
		||||
                setCustomModel(value);
 | 
			
		||||
              }}
 | 
			
		||||
              onKeyDown={(e) => {
 | 
			
		||||
                if (e.key === 'Enter') {
 | 
			
		||||
                  addCustomModel();
 | 
			
		||||
                  e.preventDefault();
 | 
			
		||||
                }
 | 
			
		||||
              }}
 | 
			
		||||
            />
 | 
			
		||||
          </div>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
@@ -355,7 +411,7 @@ const EditChannel = () => {
 | 
			
		||||
                label='密钥'
 | 
			
		||||
                name='key'
 | 
			
		||||
                required
 | 
			
		||||
                placeholder={inputs.type === 15 ? '请输入 access token,当前版本暂不支持自动刷新,请每 30 天更新一次' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')}
 | 
			
		||||
                placeholder={type2secretPrompt(inputs.type)}
 | 
			
		||||
                onChange={handleInputChange}
 | 
			
		||||
                value={inputs.key}
 | 
			
		||||
                autoComplete='new-password'
 | 
			
		||||
@@ -373,7 +429,7 @@ const EditChannel = () => {
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type !== 3 && inputs.type !== 8 && (
 | 
			
		||||
            inputs.type !== 3 && inputs.type !== 8 && inputs.type !== 22 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='代理'
 | 
			
		||||
@@ -386,6 +442,20 @@ const EditChannel = () => {
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          {
 | 
			
		||||
            inputs.type === 22 && (
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                  label='私有部署地址'
 | 
			
		||||
                  name='base_url'
 | 
			
		||||
                  placeholder={'请输入私有部署地址,格式为:https://fastgpt.run/api/openapi'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={inputs.base_url}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
            )
 | 
			
		||||
          }
 | 
			
		||||
          <Button onClick={handleCancel}>取消</Button>
 | 
			
		||||
          <Button type={isEdit ? 'button' : 'submit'} positive onClick={submit}>提交</Button>
 | 
			
		||||
        </Form>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user